mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 18:02:14 +08:00)

Make feature extractor pluggable

Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase.
Reviewed By: davnov134
Differential Revision: D35433098
fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a

This commit is contained in:
parent cd7b885169
commit 9ec9d057cc
@@ -224,7 +224,8 @@ generic_model_args: GenericModel
         └-- hypernet_args: SRNRaymarchHyperNet
         └-- pixel_generator_args: SRNPixelGenerator
     ╘== IdrFeatureField
-└-- image_feature_extractor_args: ResNetFeatureExtractor
+└-- image_feature_extractor_*_args: FeatureExtractorBase
+    ╘== ResNetFeatureExtractor
 └-- view_sampler_args: ViewSampler
 └-- feature_aggregator_*_args: FeatureAggregatorBase
     ╘== IdentityFeatureAggregator
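The config tree above shows the new pluggable slot: the extractor is now selected through `image_feature_extractor_class_type`, and each implementation of `FeatureExtractorBase` gets its own `image_feature_extractor_<ClassType>_args` block. A minimal sketch of driving this from user code, based on the keys changed in this commit (the import path of `get_default_args` and the override values are illustrative assumptions, not part of the diff):

from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import get_default_args

# Select the extractor plugin by class type and tune its per-class args;
# the key pattern is image_feature_extractor_<ClassType>_args.
args = get_default_args(GenericModel)
args.view_pooler_enabled = True
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
args.image_feature_extractor_ResNetFeatureExtractor_args.stages = [1, 2]
model = GenericModel(**args)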
@@ -64,7 +64,7 @@ generic_model_args:
   view_pooler_args:
     view_sampler_args:
       masked_sampling: false
-  image_feature_extractor_args:
+  image_feature_extractor_ResNetFeatureExtractor_args:
     stages:
     - 1
     - 2
@@ -1,6 +1,6 @@
 generic_model_args:
-  image_feature_extractor_enabled: true
-  image_feature_extractor_args:
+  image_feature_extractor_class_type: ResNetFeatureExtractor
+  image_feature_extractor_ResNetFeatureExtractor_args:
     add_images: true
     add_masks: true
     first_max_pool: true
@@ -1,6 +1,6 @@
 generic_model_args:
-  image_feature_extractor_enabled: true
-  image_feature_extractor_args:
+  image_feature_extractor_class_type: ResNetFeatureExtractor
+  image_feature_extractor_ResNetFeatureExtractor_args:
     add_images: true
     add_masks: true
     first_max_pool: false
@@ -1,6 +1,6 @@
 generic_model_args:
-  image_feature_extractor_enabled: true
-  image_feature_extractor_args:
+  image_feature_extractor_class_type: ResNetFeatureExtractor
+  image_feature_extractor_ResNetFeatureExtractor_args:
     stages:
     - 1
     - 2
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .feature_extractor import FeatureExtractorBase
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+import torch
+from pytorch3d.implicitron.tools.config import ReplaceableBase
+
+
+class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
+    """
+    Base class for an extractor of a set of features from images.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def get_feat_dims(self) -> int:
+        """
+        Returns:
+            total number of feature dimensions of the output.
+            (i.e. sum_i(dim_i))
+        """
+        raise NotImplementedError
+
+    def forward(
+        self,
+        imgs: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> Dict[Any, torch.Tensor]:
+        """
+        Args:
+            imgs: A batch of input images of shape `(B, 3, H, W)`.
+            masks: A batch of input masks of shape `(B, 3, H, W)`.
+
+        Returns:
+            out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
+                and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
+        """
+        raise NotImplementedError
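With `FeatureExtractorBase` defined as a `ReplaceableBase`, new extractors can be plugged in by registering a subclass that implements `get_feat_dims` and `forward`, mirroring how `ResNetFeatureExtractor` is registered further down. A minimal sketch; the class `IdentityFeatureExtractor`, its feature key, and the import of `FeatureExtractorBase` from the new package are illustrative assumptions, not part of this commit:

from typing import Any, Dict, Optional

import torch
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
from pytorch3d.implicitron.tools.config import registry


@registry.register
class IdentityFeatureExtractor(FeatureExtractorBase):
    # Hypothetical extractor: exposes the raw RGB image as its only "feature".
    def get_feat_dims(self) -> int:
        return 3

    def forward(
        self,
        imgs: Optional[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[Any, torch.Tensor]:
        assert imgs is not None
        return {"rgb": imgs}

Once registered, such a class would be selectable with `image_feature_extractor_class_type: IdentityFeatureExtractor` and configured through an `image_feature_extractor_IdentityFeatureExtractor_args` block.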
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import copy
 import logging
 import math
 from typing import Any, Dict, Optional, Tuple
@@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import torch.nn.functional as Fu
 import torchvision
-from pytorch3d.implicitron.tools.config import Configurable
+from pytorch3d.implicitron.tools.config import registry
+
+from . import FeatureExtractorBase
 
 
 logger = logging.getLogger(__name__)
@@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
 _RESNET_STD = [0.229, 0.224, 0.225]
 
 
-class ResNetFeatureExtractor(Configurable, torch.nn.Module):
+@registry.register
+class ResNetFeatureExtractor(FeatureExtractorBase):
     """
     Implements an image feature extractor. Depending on the settings allows
     to extract:
@@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
         return (img - self._resnet_mean) / self._resnet_std
 
     def get_feat_dims(self) -> int:
-        return (
-            sum(self._feat_dim.values())  # pyre-fixme[29]
-            if len(self._feat_dim) > 0  # pyre-fixme[6]
-            else 0
-        )
+        # pyre-fixme[29]
+        return sum(self._feat_dim.values())
 
     def forward(
-        self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
+        self,
+        imgs: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Dict[Any, torch.Tensor]:
         """
         Args:
@@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
         out_feats = {}
 
         imgs_input = imgs
-        if self.image_rescale != 1.0:
+        if self.image_rescale != 1.0 and imgs_input is not None:
             imgs_resized = Fu.interpolate(
                 imgs_input,
                 # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
@@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
         else:
             imgs_resized = imgs_input
 
-        if self.normalize_image:
-            imgs_normed = self._resnet_normalize_image(imgs_resized)
-        else:
-            imgs_normed = imgs_resized
-
         if len(self.stages) > 0:
+            assert imgs_resized is not None
+
+            if self.normalize_image:
+                imgs_normed = self._resnet_normalize_image(imgs_resized)
+            else:
+                imgs_normed = imgs_resized
             # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
             #  is not a function.
             feats = self.stem(imgs_normed)
@@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
             out_feats[MASK_FEATURE_NAME] = masks
 
         if self.add_images:
-            assert imgs_input is not None
+            assert imgs_resized is not None
             out_feats[IMAGE_FEATURE_NAME] = imgs_resized
 
         if self.feature_rescale != 1.0:
@@ -5,6 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 
+# Note: The #noqa comments below are for unused imports of pluggable implementations
+# which are part of implicitron. They ensure that the registry is prepopulated.
+
 import logging
 import math
 import warnings
@@ -27,6 +30,8 @@ from visdom import Visdom
 
 from .autodecoder import Autodecoder
 from .base_model import ImplicitronModelBase, ImplicitronRender
+from .feature_extractor import FeatureExtractorBase
+from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor  # noqa
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
 from .implicit_function.neural_radiance_field import (  # noqa
@@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer  # noqa
 from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer  # noqa
 from .renderer.ray_sampler import RaySamplerBase
 from .renderer.sdf_renderer import SignedDistanceFunctionRenderer  # noqa
-from .resnet_feature_extractor import ResNetFeatureExtractor
 from .view_pooler.view_pooler import ViewPooler
 
 
@@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
             splatting onto an image grid. Default: False.
         bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
-        view_pool: If True, features are sampled from the source image(s)
-            at the projected 2d locations of the sampled 3d ray points from the target
-            view(s), i.e. this activates step (3) above.
         num_passes: The specified implicit_function is initialized num_passes
             times and run sequentially.
         chunk_size_grid: The total number of points which can be rendered
@@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             registry.
         renderer: A renderer class which inherits from BaseRenderer. This is used to
             generate the images from the target view(s).
-        image_feature_extractor_enabled: If `True`, constructs and enables
-            the `image_feature_extractor` object.
+        image_feature_extractor_class_type: If a str, constructs and enables
+            the `image_feature_extractor` object of this type. Or None if not needed.
         image_feature_extractor: A module for extrating features from an input image.
         view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
+            This means features are sampled from the source image(s)
+            at the projected 2d locations of the sampled 3d ray points from the target
+            view(s), i.e. this activates step (3) above.
         view_pooler: An instance of ViewPooler which is used for sampling of
             image-based features at the 2D projections of a set
             of 3D points and aggregating the sampled features.
@@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     renderer: BaseRenderer
 
     # ---- image feature extractor settings
-    image_feature_extractor_enabled: bool = False
-    image_feature_extractor: Optional[ResNetFeatureExtractor]
+    # (This is only created if view_pooler is enabled)
+    image_feature_extractor: Optional[FeatureExtractorBase]
+    image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
     # ---- view pooler settings
     view_pooler_enabled: bool = False
     view_pooler: Optional[ViewPooler]
@@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         super().__init__()
         self.view_metrics = ViewMetrics()
 
+        if self.view_pooler_enabled:
+            if self.image_feature_extractor_class_type is None:
+                raise ValueError(
+                    "image_feature_extractor must be present for view pooling."
+                )
+        else:
+            self.image_feature_extractor_class_type = None
         run_auto_creation(self)
 
         self._implicit_functions = self._construct_implicit_functions()
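The `__post_init__` logic above couples the two knobs: enabling the view pooler requires an `image_feature_extractor_class_type`, and disabling it clears the class type so that `run_auto_creation` builds no extractor. A rough illustration of the expected behaviour (not a test from this commit; the effect of a `None` class type depends on implicitron's auto-creation machinery, and the `get_default_args` import path is assumed):

from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(GenericModel)

# View pooling is off by default, so the class type is cleared in
# __post_init__ and no extractor should be constructed.
model = GenericModel(**args)
assert model.image_feature_extractor is None

# Turning view pooling on without an extractor class type should raise
# the new ValueError added above.
args.view_pooler_enabled = True
args.image_feature_extractor_class_type = None
try:
    GenericModel(**args)
except ValueError as err:
    print(err)  # image_feature_extractor must be present for view pooling.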
@@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         # custom_args hold additional arguments to the implicit function.
         custom_args = {}
 
-        if self.image_feature_extractor_enabled:
+        if self.image_feature_extractor is not None:
             # (2) Extract features for the image
-            img_feats = self.image_feature_extractor(  # pyre-fixme[29]
-                image_rgb, fg_probability
-            )
+            img_feats = self.image_feature_extractor(image_rgb, fg_probability)
         else:
             img_feats = None
 
         if self.view_pooler_enabled:
             if sequence_name is None:
                 raise ValueError("sequence_name must be provided for view pooling")
-            if not self.image_feature_extractor_enabled:
-                raise ValueError(
-                    "image_feature_extractor has to be enabled for for view pooling"
-                    + " (I.e. set self.image_feature_extractor_enabled=True)."
-                )
+            assert img_feats is not None
 
             # (3-4) Sample features and masks at the ray points.
             #       Aggregate features from multiple views.
@@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
                 **kwargs,
             )
 
-    def _get_viewpooled_feature_dim(self):
-        return (
-            self.view_pooler.get_aggregated_feature_dim(
-                self.image_feature_extractor.get_feat_dims()
-            )
-            if self.view_pooler_enabled
-            else 0
+    def _get_viewpooled_feature_dim(self) -> int:
+        if self.view_pooler is None:
+            return 0
+        assert self.image_feature_extractor is not None
+        return self.view_pooler.get_aggregated_feature_dim(
+            self.image_feature_extractor.get_feat_dims()
         )
 
     def create_raysampler(self):
@@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         that image_feature_extractor is enabled when view_pooler is enabled.
         """
         if self.view_pooler_enabled:
-            if not self.image_feature_extractor_enabled:
-                raise ValueError(
-                    "image_feature_extractor has to be enabled for view pooling"
-                    + " (I.e. set self.image_feature_extractor_enabled=True)."
-                )
             self.view_pooler = ViewPooler(**self.view_pooler_args)
         else:
             self.view_pooler = None
@@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
 sampling_mode_evaluation: full_grid
 raysampler_class_type: AdaptiveRaySampler
 renderer_class_type: LSTMRenderer
-image_feature_extractor_enabled: true
+image_feature_extractor_class_type: ResNetFeatureExtractor
 view_pooler_enabled: true
 implicit_function_class_type: IdrFeatureField
 loss_weights:
@@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
   hidden_size: 16
   n_feature_channels: 256
   verbose: false
-image_feature_extractor_args:
+image_feature_extractor_ResNetFeatureExtractor_args:
   name: resnet34
   pretrained: true
   stages:
@@ -9,6 +9,9 @@ import unittest
 
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.models.autodecoder import Autodecoder
+from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
+    ResNetFeatureExtractor,
+)
 from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
@@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
         provide_resnet34()
         args = get_default_args(GenericModel)
         args.view_pooler_enabled = True
-        args.image_feature_extractor_enabled = True
         args.view_pooler_args.feature_aggregator_class_type = (
             "AngleWeightedIdentityFeatureAggregator"
         )
@@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
         )
         self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
         self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
+        self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
         self.assertFalse(hasattr(gm, "implicit_function"))
 
         instance_args = OmegaConf.structured(gm)
+        if DEBUG:
+            full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
+            (DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
@@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
         # Simple test of a forward and backward pass of the default GenericModel.
         device = torch.device("cuda:1")
         expand_args_fields(GenericModel)
-        model = GenericModel()
+        model = GenericModel(render_image_height=80, render_image_width=80)
         model.to(device)
         self._one_model_test(model, device)
 
@@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
         )
         self.assertGreater(train_preds["objective"].item(), 0)
 
+    def test_viewpool(self):
+        device = torch.device("cuda:1")
+        args = get_default_args(GenericModel)
+        args.view_pooler_enabled = True
+        args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
+        model = GenericModel(**args)
+        model.to(device)
+
+        n_train_cameras = 2
+        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
+        cameras = PerspectiveCameras(R=R, T=T, device=device)
+
+        defaulted_args = {
+            "fg_probability": None,
+            "depth_map": None,
+            "mask_crop": None,
+        }
+
+        target_image_rgb = torch.rand(
+            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
+            device=device,
+        )
+        train_preds = model(
+            camera=cameras,
+            evaluation_mode=EvaluationMode.TRAINING,
+            image_rgb=target_image_rgb,
+            sequence_name=["a"] * n_train_cameras,
+            **defaulted_args,
+        )
+        self.assertGreater(train_preds["objective"].item(), 0)
+
 
 def _random_input_tensor(
     N: int,