mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Make feature extractor pluggable
Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase. Reviewed By: davnov134 Differential Revision: D35433098 fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
parent
cd7b885169
commit
9ec9d057cc
@ -224,7 +224,8 @@ generic_model_args: GenericModel
|
|||||||
└-- hypernet_args: SRNRaymarchHyperNet
|
└-- hypernet_args: SRNRaymarchHyperNet
|
||||||
└-- pixel_generator_args: SRNPixelGenerator
|
└-- pixel_generator_args: SRNPixelGenerator
|
||||||
╘== IdrFeatureField
|
╘== IdrFeatureField
|
||||||
└-- image_feature_extractor_args: ResNetFeatureExtractor
|
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
||||||
|
╘== ResNetFeatureExtractor
|
||||||
└-- view_sampler_args: ViewSampler
|
└-- view_sampler_args: ViewSampler
|
||||||
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
||||||
╘== IdentityFeatureAggregator
|
╘== IdentityFeatureAggregator
|
||||||
|
@ -64,7 +64,7 @@ generic_model_args:
|
|||||||
view_pooler_args:
|
view_pooler_args:
|
||||||
view_sampler_args:
|
view_sampler_args:
|
||||||
masked_sampling: false
|
masked_sampling: false
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
stages:
|
stages:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_enabled: true
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
first_max_pool: true
|
first_max_pool: true
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_enabled: true
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
first_max_pool: false
|
first_max_pool: false
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_enabled: true
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
stages:
|
stages:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .feature_extractor import FeatureExtractorBase
|
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for an extractor of a set of features from images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_feat_dims(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
total number of feature dimensions of the output.
|
||||||
|
(i.e. sum_i(dim_i))
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
imgs: Optional[torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[Any, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
imgs: A batch of input images of shape `(B, 3, H, W)`.
|
||||||
|
masks: A batch of input masks of shape `(B, 3, H, W)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
|
||||||
|
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
@ -4,7 +4,6 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import copy
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as Fu
|
import torch.nn.functional as Fu
|
||||||
import torchvision
|
import torchvision
|
||||||
from pytorch3d.implicitron.tools.config import Configurable
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
|
from . import FeatureExtractorBase
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
|
|||||||
_RESNET_STD = [0.229, 0.224, 0.225]
|
_RESNET_STD = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
|
||||||
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
@registry.register
|
||||||
|
class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||||
"""
|
"""
|
||||||
Implements an image feature extractor. Depending on the settings allows
|
Implements an image feature extractor. Depending on the settings allows
|
||||||
to extract:
|
to extract:
|
||||||
@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self) -> int:
|
def get_feat_dims(self) -> int:
|
||||||
return (
|
# pyre-fixme[29]
|
||||||
sum(self._feat_dim.values()) # pyre-fixme[29]
|
return sum(self._feat_dim.values())
|
||||||
if len(self._feat_dim) > 0 # pyre-fixme[6]
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
self,
|
||||||
|
imgs: Optional[torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Dict[Any, torch.Tensor]:
|
) -> Dict[Any, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
out_feats = {}
|
out_feats = {}
|
||||||
|
|
||||||
imgs_input = imgs
|
imgs_input = imgs
|
||||||
if self.image_rescale != 1.0:
|
if self.image_rescale != 1.0 and imgs_input is not None:
|
||||||
imgs_resized = Fu.interpolate(
|
imgs_resized = Fu.interpolate(
|
||||||
imgs_input,
|
imgs_input,
|
||||||
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
|
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
|
||||||
@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
imgs_resized = imgs_input
|
imgs_resized = imgs_input
|
||||||
|
|
||||||
if self.normalize_image:
|
|
||||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
|
||||||
else:
|
|
||||||
imgs_normed = imgs_resized
|
|
||||||
|
|
||||||
if len(self.stages) > 0:
|
if len(self.stages) > 0:
|
||||||
|
assert imgs_resized is not None
|
||||||
|
|
||||||
|
if self.normalize_image:
|
||||||
|
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||||
|
else:
|
||||||
|
imgs_normed = imgs_resized
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
||||||
# is not a function.
|
# is not a function.
|
||||||
feats = self.stem(imgs_normed)
|
feats = self.stem(imgs_normed)
|
||||||
@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
out_feats[MASK_FEATURE_NAME] = masks
|
out_feats[MASK_FEATURE_NAME] = masks
|
||||||
|
|
||||||
if self.add_images:
|
if self.add_images:
|
||||||
assert imgs_input is not None
|
assert imgs_resized is not None
|
||||||
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
||||||
|
|
||||||
if self.feature_rescale != 1.0:
|
if self.feature_rescale != 1.0:
|
@ -5,6 +5,9 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
# Note: The #noqa comments below are for unused imports of pluggable implementations
|
||||||
|
# which are part of implicitron. They ensure that the registry is prepopulated.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
@ -27,6 +30,8 @@ from visdom import Visdom
|
|||||||
|
|
||||||
from .autodecoder import Autodecoder
|
from .autodecoder import Autodecoder
|
||||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||||
|
from .feature_extractor import FeatureExtractorBase
|
||||||
|
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
||||||
from .implicit_function.base import ImplicitFunctionBase
|
from .implicit_function.base import ImplicitFunctionBase
|
||||||
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
||||||
from .implicit_function.neural_radiance_field import ( # noqa
|
from .implicit_function.neural_radiance_field import ( # noqa
|
||||||
@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer # noqa
|
|||||||
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
||||||
from .renderer.ray_sampler import RaySamplerBase
|
from .renderer.ray_sampler import RaySamplerBase
|
||||||
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
||||||
from .resnet_feature_extractor import ResNetFeatureExtractor
|
|
||||||
from .view_pooler.view_pooler import ViewPooler
|
from .view_pooler.view_pooler import ViewPooler
|
||||||
|
|
||||||
|
|
||||||
@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
||||||
splatting onto an image grid. Default: False.
|
splatting onto an image grid. Default: False.
|
||||||
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
||||||
view_pool: If True, features are sampled from the source image(s)
|
|
||||||
at the projected 2d locations of the sampled 3d ray points from the target
|
|
||||||
view(s), i.e. this activates step (3) above.
|
|
||||||
num_passes: The specified implicit_function is initialized num_passes
|
num_passes: The specified implicit_function is initialized num_passes
|
||||||
times and run sequentially.
|
times and run sequentially.
|
||||||
chunk_size_grid: The total number of points which can be rendered
|
chunk_size_grid: The total number of points which can be rendered
|
||||||
@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
registry.
|
registry.
|
||||||
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
||||||
generate the images from the target view(s).
|
generate the images from the target view(s).
|
||||||
image_feature_extractor_enabled: If `True`, constructs and enables
|
image_feature_extractor_class_type: If a str, constructs and enables
|
||||||
the `image_feature_extractor` object.
|
the `image_feature_extractor` object of this type. Or None if not needed.
|
||||||
image_feature_extractor: A module for extrating features from an input image.
|
image_feature_extractor: A module for extrating features from an input image.
|
||||||
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
||||||
|
This means features are sampled from the source image(s)
|
||||||
|
at the projected 2d locations of the sampled 3d ray points from the target
|
||||||
|
view(s), i.e. this activates step (3) above.
|
||||||
view_pooler: An instance of ViewPooler which is used for sampling of
|
view_pooler: An instance of ViewPooler which is used for sampling of
|
||||||
image-based features at the 2D projections of a set
|
image-based features at the 2D projections of a set
|
||||||
of 3D points and aggregating the sampled features.
|
of 3D points and aggregating the sampled features.
|
||||||
@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- image feature extractor settings
|
# ---- image feature extractor settings
|
||||||
image_feature_extractor_enabled: bool = False
|
# (This is only created if view_pooler is enabled)
|
||||||
image_feature_extractor: Optional[ResNetFeatureExtractor]
|
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||||
|
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
|
||||||
# ---- view pooler settings
|
# ---- view pooler settings
|
||||||
view_pooler_enabled: bool = False
|
view_pooler_enabled: bool = False
|
||||||
view_pooler: Optional[ViewPooler]
|
view_pooler: Optional[ViewPooler]
|
||||||
@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.view_metrics = ViewMetrics()
|
self.view_metrics = ViewMetrics()
|
||||||
|
|
||||||
|
if self.view_pooler_enabled:
|
||||||
|
if self.image_feature_extractor_class_type is None:
|
||||||
|
raise ValueError(
|
||||||
|
"image_feature_extractor must be present for view pooling."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.image_feature_extractor_class_type = None
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
self._implicit_functions = self._construct_implicit_functions()
|
self._implicit_functions = self._construct_implicit_functions()
|
||||||
@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# custom_args hold additional arguments to the implicit function.
|
# custom_args hold additional arguments to the implicit function.
|
||||||
custom_args = {}
|
custom_args = {}
|
||||||
|
|
||||||
if self.image_feature_extractor_enabled:
|
if self.image_feature_extractor is not None:
|
||||||
# (2) Extract features for the image
|
# (2) Extract features for the image
|
||||||
img_feats = self.image_feature_extractor( # pyre-fixme[29]
|
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
||||||
image_rgb, fg_probability
|
else:
|
||||||
)
|
img_feats = None
|
||||||
|
|
||||||
if self.view_pooler_enabled:
|
if self.view_pooler_enabled:
|
||||||
if sequence_name is None:
|
if sequence_name is None:
|
||||||
raise ValueError("sequence_name must be provided for view pooling")
|
raise ValueError("sequence_name must be provided for view pooling")
|
||||||
if not self.image_feature_extractor_enabled:
|
assert img_feats is not None
|
||||||
raise ValueError(
|
|
||||||
"image_feature_extractor has to be enabled for for view pooling"
|
|
||||||
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
|
||||||
)
|
|
||||||
|
|
||||||
# (3-4) Sample features and masks at the ray points.
|
# (3-4) Sample features and masks at the ray points.
|
||||||
# Aggregate features from multiple views.
|
# Aggregate features from multiple views.
|
||||||
@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_viewpooled_feature_dim(self):
|
def _get_viewpooled_feature_dim(self) -> int:
|
||||||
return (
|
if self.view_pooler is None:
|
||||||
self.view_pooler.get_aggregated_feature_dim(
|
return 0
|
||||||
self.image_feature_extractor.get_feat_dims()
|
assert self.image_feature_extractor is not None
|
||||||
)
|
return self.view_pooler.get_aggregated_feature_dim(
|
||||||
if self.view_pooler_enabled
|
self.image_feature_extractor.get_feat_dims()
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_raysampler(self):
|
def create_raysampler(self):
|
||||||
@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
that image_feature_extractor is enabled when view_pooler is enabled.
|
that image_feature_extractor is enabled when view_pooler is enabled.
|
||||||
"""
|
"""
|
||||||
if self.view_pooler_enabled:
|
if self.view_pooler_enabled:
|
||||||
if not self.image_feature_extractor_enabled:
|
|
||||||
raise ValueError(
|
|
||||||
"image_feature_extractor has to be enabled for view pooling"
|
|
||||||
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
|
||||||
)
|
|
||||||
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
||||||
else:
|
else:
|
||||||
self.view_pooler = None
|
self.view_pooler = None
|
||||||
|
@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
|
|||||||
sampling_mode_evaluation: full_grid
|
sampling_mode_evaluation: full_grid
|
||||||
raysampler_class_type: AdaptiveRaySampler
|
raysampler_class_type: AdaptiveRaySampler
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
image_feature_extractor_enabled: true
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: IdrFeatureField
|
implicit_function_class_type: IdrFeatureField
|
||||||
loss_weights:
|
loss_weights:
|
||||||
@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
|
|||||||
hidden_size: 16
|
hidden_size: 16
|
||||||
n_feature_channels: 256
|
n_feature_channels: 256
|
||||||
verbose: false
|
verbose: false
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
name: resnet34
|
name: resnet34
|
||||||
pretrained: true
|
pretrained: true
|
||||||
stages:
|
stages:
|
||||||
|
@ -9,6 +9,9 @@ import unittest
|
|||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||||
|
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||||
|
ResNetFeatureExtractor,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||||
IdrFeatureField,
|
IdrFeatureField,
|
||||||
@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
provide_resnet34()
|
provide_resnet34()
|
||||||
args = get_default_args(GenericModel)
|
args = get_default_args(GenericModel)
|
||||||
args.view_pooler_enabled = True
|
args.view_pooler_enabled = True
|
||||||
args.image_feature_extractor_enabled = True
|
|
||||||
args.view_pooler_args.feature_aggregator_class_type = (
|
args.view_pooler_args.feature_aggregator_class_type = (
|
||||||
"AngleWeightedIdentityFeatureAggregator"
|
"AngleWeightedIdentityFeatureAggregator"
|
||||||
)
|
)
|
||||||
@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||||
|
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||||
|
|
||||||
instance_args = OmegaConf.structured(gm)
|
instance_args = OmegaConf.structured(gm)
|
||||||
|
if DEBUG:
|
||||||
|
full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||||
|
(DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
|
||||||
remove_unused_components(instance_args)
|
remove_unused_components(instance_args)
|
||||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
|
@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
# Simple test of a forward and backward pass of the default GenericModel.
|
# Simple test of a forward and backward pass of the default GenericModel.
|
||||||
device = torch.device("cuda:1")
|
device = torch.device("cuda:1")
|
||||||
expand_args_fields(GenericModel)
|
expand_args_fields(GenericModel)
|
||||||
model = GenericModel()
|
model = GenericModel(render_image_height=80, render_image_width=80)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
self._one_model_test(model, device)
|
self._one_model_test(model, device)
|
||||||
|
|
||||||
@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertGreater(train_preds["objective"].item(), 0)
|
self.assertGreater(train_preds["objective"].item(), 0)
|
||||||
|
|
||||||
|
def test_viewpool(self):
|
||||||
|
device = torch.device("cuda:1")
|
||||||
|
args = get_default_args(GenericModel)
|
||||||
|
args.view_pooler_enabled = True
|
||||||
|
args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
|
||||||
|
model = GenericModel(**args)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
n_train_cameras = 2
|
||||||
|
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||||
|
cameras = PerspectiveCameras(R=R, T=T, device=device)
|
||||||
|
|
||||||
|
defaulted_args = {
|
||||||
|
"fg_probability": None,
|
||||||
|
"depth_map": None,
|
||||||
|
"mask_crop": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
target_image_rgb = torch.rand(
|
||||||
|
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
train_preds = model(
|
||||||
|
camera=cameras,
|
||||||
|
evaluation_mode=EvaluationMode.TRAINING,
|
||||||
|
image_rgb=target_image_rgb,
|
||||||
|
sequence_name=["a"] * n_train_cameras,
|
||||||
|
**defaulted_args,
|
||||||
|
)
|
||||||
|
self.assertGreater(train_preds["objective"].item(), 0)
|
||||||
|
|
||||||
|
|
||||||
def _random_input_tensor(
|
def _random_input_tensor(
|
||||||
N: int,
|
N: int,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user