ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
David Novotny 2022-05-12 12:50:03 -07:00 committed by Facebook GitHub Bot
parent bef959c755
commit 47d06c8924
26 changed files with 304 additions and 110 deletions

View File

@ -63,8 +63,9 @@ generic_model_args:
n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true
density_noise_std_train: 1.0
view_sampler_args:
masked_sampling: false
view_pooler_args:
view_sampler_args:
masked_sampling: false
image_feature_extractor_args:
stages:
- 1

View File

@ -1,4 +1,5 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
add_images: true
add_masks: true

View File

@ -1,4 +1,5 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
add_images: true
add_masks: true

View File

@ -1,4 +1,5 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
stages:
- 1
@ -11,6 +12,7 @@ generic_model_args:
name: resnet34
normalize_image: true
pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG
view_pooler_args:
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG

View File

@ -11,7 +11,6 @@ generic_model_args:
num_passes: 1
output_rasterized_mc: true
sampling_mode_training: mask_sample
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
init_scale: 1.0

View File

@ -3,7 +3,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
view_pooler_enabled: false
sequence_autodecoder_args:
n_instances: 20000
encoding_dim: 256

View File

@ -5,6 +5,6 @@ defaults:
clip_grad: 1.0
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: IdentityFeatureAggregator
view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

View File

@ -3,7 +3,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
view_pooler_enabled: false
n_train_target_views: -1
num_passes: 1
loss_weights:

View File

@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 32000
view_pool: true
num_passes: 1
n_train_target_views: -1
loss_weights:
@ -25,6 +24,7 @@ generic_model_args:
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
view_pooler_enabled: true
solver_args:
breed: adam
lr: 5.0e-05

View File

@ -9,7 +9,7 @@ generic_model_args:
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
view_pool: false
view_pooler_enabled: false
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6

View File

@ -4,6 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@ -4,7 +4,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
@ -13,4 +13,5 @@ generic_model_args:
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator
view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@ -4,7 +4,7 @@ defaults:
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
view_pooler_enabled: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0

View File

@ -5,7 +5,7 @@ defaults:
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
view_pooler_enabled: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0

View File

@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase
from .view_pooling.view_sampling import ViewSampler
from .view_pooler.view_pooler import ViewPooler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor_enabled: If `True`, constructs and enables
the `image_feature_extractor` object.
image_feature_extractor: A module for extrating features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
view_pooler: An instance of ViewPooler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
of 3D points and aggregating the sampled features.
implicit_function_class_type: The type of implicit function to use which
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
mask_threshold: float = 0.5
output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
view_pool: bool = False
num_passes: int = 1
chunk_size_grid: int = 4096
render_features_dimensions: int = 3
@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
renderer: BaseRenderer
# ---- view sampling settings - used if view_pool=True
# (This is only created if view_pool is False)
image_feature_extractor: ResNetFeatureExtractor
view_sampler: ViewSampler
# ---- ---- view sampling feature aggregator settings
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
feature_aggregator: FeatureAggregatorBase
# ---- image feature extractor settings
image_feature_extractor_enabled: bool = False
image_feature_extractor: Optional[ResNetFeatureExtractor]
# ---- view pooler settings
view_pooler_enabled: bool = False
view_pooler: Optional[ViewPooler]
# ---- implicit function settings
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# custom_args hold additional arguments to the implicit function.
custom_args = {}
if self.view_pool:
if self.image_feature_extractor_enabled:
# (2) Extract features for the image
img_feats = self.image_feature_extractor( # pyre-fixme[29]
image_rgb, fg_probability
)
if self.view_pooler_enabled:
if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling")
# (2) Extract features for the image
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
# (3) Sample features and masks at the ray points
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
) # returns feats_sampled, masks_sampled
# (3-4) Sample features and masks at the ray points.
# Aggregate features from multiple views.
def curried_viewpooler(pts):
return self.view_pooler(
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
)
# (4) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
*curried_view_sampler(pts=pts),
pts=pts,
camera=camera,
) # TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args["fun_viewpool"] = curried_view_pool
custom_args["fun_viewpool"] = curried_viewpooler
global_code = None
if self.sequence_autodecoder.n_instances > 0:
@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
def _get_viewpooled_feature_dim(self):
return (
self.feature_aggregator.get_aggregated_feature_dim(
self.view_pooler.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
if self.view_pool
if self.view_pooler_enabled
else 0
)
@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"object_bounding_sphere"
] = self.raysampler_args["scene_extent"]
def create_image_feature_extractor(self):
def create_view_pooler(self):
"""
Custom creation function called by run_auto_creation so that the
image_feature_extractor is not created if it is not be needed.
Custom creation function called by run_auto_creation checking
that image_feature_extractor is enabled when view_pooler is enabled.
"""
if self.view_pool:
self.image_feature_extractor = ResNetFeatureExtractor(
**self.image_feature_extractor_args
)
if self.view_pooler_enabled:
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
self.view_pooler = ViewPooler(**self.view_pooler_args)
else:
self.view_pooler = None
def create_implicit_function(self) -> None:
"""
@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
)
if implicit_function_type.requires_pooling_without_aggregation():
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
if not self.view_pool or has_aggregation:
if self.view_pooler_enabled and self.view_pooler.has_aggregation():
raise ValueError(
"Chosen implicit function requires view pooling without aggregation."
"The chosen implicit function requires view pooling without aggregation."
)
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
config = getattr(self, config_name, None)

View File

@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self, size_dict: bool = False):
if size_dict:
return copy.deepcopy(self._feat_dim)
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na...
return sum(self._feat_dim.values())
def get_feat_dims(self) -> int:
return (
sum(self._feat_dim.values()) # pyre-fixme[29]
if len(self._feat_dim) > 0 # pyre-fixme[6]
else 0
)
def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None

View File

@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
"""
raise NotImplementedError()
@abstractmethod
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
"""
Returns the final dimensionality of the output aggregated features.
Args:
feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding
to the `feats_sampled` argument of `forward`,
or an `int` representing the sum of dimensionalities of each `t_i`.
Returns:
aggregated_feature_dim: The final dimensionality of the output
aggregated features.
"""
raise NotImplementedError()
def has_aggregation(self) -> bool:
"""
Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1.
Returns:
has_aggregation: `True` if `reduce_dim==1`, else `False`.
"""
return hasattr(self, "reduction_functions")
@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
def forward(
self,
@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(
feats_or_feats_dim, self.reduction_functions
)
def forward(
self,
@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(
feats_or_feats_dim, self.reduction_functions
)
def forward(
self,
@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
def forward(
self,

View File

@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation
from pytorch3d.renderer.cameras import CamerasBase
from .feature_aggregator import FeatureAggregatorBase
from .view_sampler import ViewSampler
# pyre-ignore: 13
class ViewPooler(Configurable, torch.nn.Module):
"""
Implements sampling of image-based features at the 2d projections of a set
of 3D points, and a subsequent aggregation of the resulting set of features
per-point.
Args:
view_sampler: An instance of ViewSampler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
"""
view_sampler: ViewSampler
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
feature_aggregator: FeatureAggregatorBase
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
"""
Returns the final dimensionality of the output aggregated features.
Args:
feats: Either a `dict` of sampled features `{f_i: t_i}` corresponding
to the `feats_sampled` argument of `feature_aggregator,forward`,
or an `int` representing the sum of dimensionalities of each `t_i`.
Returns:
aggregated_feature_dim: The final dimensionality of the output
aggregated features.
"""
return self.feature_aggregator.get_aggregated_feature_dim(feats)
def has_aggregation(self):
"""
Specifies whether the `feature_aggregator` reduces the output `reduce_dim`
dimension to 1.
Returns:
has_aggregation: `True` if `reduce_dim==1`, else `False`.
"""
return self.feature_aggregator.has_aggregation()
def forward(
self,
*, # force kw args
pts: torch.Tensor,
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
camera: CamerasBase,
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
feats: Dict[str, torch.Tensor],
masks: Optional[torch.Tensor],
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Project each point cloud from a batch of point clouds to corresponding
input cameras, sample features at the 2D projection locations in a batch
of source images, and aggregate the pointwise sampled features.
Args:
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
from which `pts` were extracted, or a list of string names.
camera: 'n_cameras' cameras, each coresponding to a batch element of `feats`.
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
corresponding to cameras in `camera`, or a list of string names.
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
masks: `[n_cameras x 1 x H x W]`, define valid image regions
for sampling `feats`.
Returns:
feats_aggregated: If `feature_aggregator.concatenate_output==True`, a tensor
of shape `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))`
containing the aggregated features. `reduce_dim` depends on
the specific feature aggregator implementation and typically
equals 1 or `n_cameras`.
If `feature_aggregator.concatenate_output==False`, the aggregator
does not concatenate the aggregated features and returns a dictionary
of per-feature aggregations `{f_i: t_i_aggregated}` instead.
Each `t_i_aggregated` is of shape
`(pts_batch, reduce_dim, n_pts, aggr_dim_i)`.
"""
# (1) Sample features and masks at the ray points
sampled_feats, sampled_masks = self.view_sampler(
pts=pts,
seq_id_pts=seq_id_pts,
camera=camera,
seq_id_camera=seq_id_camera,
feats=feats,
masks=masks,
)
# (2) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
feats_aggregated = self.feature_aggregator( # noqa: E731
sampled_feats,
sampled_masks,
pts=pts,
camera=camera,
) # TODO: do we need to pass a callback rather than compute here?
return feats_aggregated

View File

@ -8,7 +8,6 @@ bg_color:
- 0.0
- 0.0
- 0.0
view_pool: false
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
@ -17,7 +16,8 @@ n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
renderer_class_type: LSTMRenderer
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
image_feature_extractor_enabled: true
view_pooler_enabled: true
implicit_function_class_type: IdrFeatureField
loss_weights:
loss_rgb_mse: 1.0
@ -91,15 +91,17 @@ image_feature_extractor_args:
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3

View File

@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import (
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
from .common_resources import provide_lpips_vgg
else:
from common_resources import provide_lpips_vgg # noqa
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase):
self.maxDiff = None
def test_create_gm(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
def test_create_gm_overrides(self):
def _test_create_gm_overrides(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True
args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator"
)
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
gm.view_pooler.feature_aggregator,
AngleWeightedIdentityFeatureAggregator,
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)

View File

@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
cfg = _load_model_config_from_yaml(str(config_file))
model = GenericModel(**cfg)
model.to(device)
self._one_model_test(model, device, eval_test=True)
self._one_model_test(
model,
device,
eval_test=True,
bw_test=True,
)
def _one_model_test(
self,
@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
device,
n_train_cameras: int = 5,
eval_test: bool = True,
bw_test: bool = True,
):
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase):
**random_args,
evaluation_mode=EvaluationMode.TRAINING,
)
self.assertGreater(train_preds["objective"].item(), 0)
train_preds["objective"].backward()
self.assertTrue(
train_preds["objective"].isfinite().item()
) # check finiteness of the objective
if bw_test:
train_preds["objective"].backward()
if eval_test:
model.eval()

View File

@ -9,7 +9,7 @@ import unittest
import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields