Repository: https://github.com/facebookresearch/pytorch3d.git
Refactor autodecoders

Summary: Refactors autodecoders. Tests pass.

Reviewed By: bottler

Differential Revision: D37592429

fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
parent ae35824f21
commit 0dce883241
@@ -11,10 +11,12 @@ generic_model_args:
   num_passes: 1
   output_rasterized_mc: true
   sampling_mode_training: mask_sample
-  sequence_autodecoder_args:
-    n_instances: 20000
-    init_scale: 1.0
-    encoding_dim: 256
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      n_instances: 20000
+      init_scale: 1.0
+      encoding_dim: 256
   implicit_function_IdrFeatureField_args:
     n_harmonic_functions_xyz: 6
     bias: 0.6
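Taken together, the config hunks in this commit replace the flat `sequence_autodecoder_args` block with a pluggable global encoder: the implementation is chosen via `global_encoder_class_type`, and the old autodecoder fields move under `global_encoder_SequenceAutodecoder_args.autodecoder_args`. A hedged sketch of the same override done from Python, assuming `get_default_args` from the Implicitron config tools; the values are the ones appearing in the hunk above:

from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import get_default_args

# Build the default GenericModel config and switch on the sequence autodecoder
# the new way; field names are taken from the YAML diff above.
args = get_default_args(GenericModel)
args.global_encoder_class_type = "SequenceAutodecoder"
autodecoder_args = args.global_encoder_SequenceAutodecoder_args.autodecoder_args
autodecoder_args.n_instances = 20000
autodecoder_args.init_scale = 1.0
autodecoder_args.encoding_dim = 256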
@@ -4,6 +4,8 @@ defaults:
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: false
-  sequence_autodecoder_args:
-    n_instances: 20000
-    encoding_dim: 256
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      n_instances: 20000
+      encoding_dim: 256
@@ -13,9 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.001
     depth_neg_penalty: 10000.0
-  sequence_autodecoder_args:
-    encoding_dim: 256
-    n_instances: 20000
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      encoding_dim: 256
+      n_instances: 20000
   raysampler_class_type: NearFarRaySampler
   raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
@@ -16,6 +16,7 @@ generic_model_args:
   n_train_target_views: 1
   sampling_mode_training: mask_sample
   sampling_mode_evaluation: full_grid
+  global_encoder_class_type: null
   raysampler_class_type: AdaptiveRaySampler
   renderer_class_type: MultiPassEmissionAbsorptionRenderer
   image_feature_extractor_class_type: null
@@ -49,11 +50,16 @@ generic_model_args:
   - objective
   - epoch
   - sec/it
-  sequence_autodecoder_args:
-    encoding_dim: 0
-    n_instances: 0
-    init_scale: 1.0
-    ignore_input: false
+  global_encoder_HarmonicTimeEncoder_args:
+    n_harmonic_functions: 10
+    append_input: true
+    time_divisor: 1.0
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      encoding_dim: 0
+      n_instances: 0
+      init_scale: 1.0
+      ignore_input: false
   raysampler_AdaptiveRaySampler_args:
     image_width: 400
     image_height: 400
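The `global_encoder_HarmonicTimeEncoder_args` defaults above belong to the second encoder implementation added by this commit. A hedged usage sketch of that encoder on its own, assuming `expand_args_fields` is the usual way to instantiate an Implicitron configurable directly:

import torch

from pytorch3d.implicitron.models.global_encoder.global_encoder import (
    HarmonicTimeEncoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Illustrative only: harmonically embed a batch of four frame timestamps.
expand_args_fields(HarmonicTimeEncoder)
encoder = HarmonicTimeEncoder(n_harmonic_functions=10, append_input=True)

timestamps = torch.linspace(0.0, 1.0, steps=4)[:, None]  # (B, 1); last dim must be 1
codes = encoder(frame_timestamp=timestamps)
# 2 * 10 harmonics plus the appended input give 21 dims per frame.
print(codes.shape, encoder.get_encoding_dim())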
@@ -12,7 +12,7 @@ import logging
 import math
 import warnings
 from dataclasses import field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import tqdm
@@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils
 from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom

-from .autodecoder import Autodecoder
 from .base_model import ImplicitronModelBase, ImplicitronRender
 from .feature_extractor import FeatureExtractorBase
 from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor  # noqa
+from .global_encoder.global_encoder import GlobalEncoderBase
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
 from .implicit_function.neural_radiance_field import (  # noqa
@@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer  # noqa
 from .view_pooler.view_pooler import ViewPooler


-STD_LOG_VARS = ["objective", "epoch", "sec/it"]
 logger = logging.getLogger(__name__)


@@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         ------------------
         Evaluate the implicit function(s) at the sampled ray points
         (optionally pass in the aggregated image features from (4)).
+        (also optionally pass in a global encoding from global_encoder).
                 │
                 ▼
         (6) Rendering
@@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         sampling_mode_training: The sampling method to use during training. Must be
             a value from the RenderSamplingMode Enum.
         sampling_mode_evaluation: Same as above but for evaluation.
-        sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
+        global_encoder_class_type: The name of the class to use for global_encoder,
+            which must be available in the registry. Or `None` to disable global encoder.
+        global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
             of the image (referred to as the global_code) that can be used to model aspects of
             the scene such as multiple objects or morphing objects. It is up to the implicit
             function definition how to use it, but the most typical way is to broadcast and
@@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     sampling_mode_training: str = "mask_sample"
     sampling_mode_evaluation: str = "full_grid"

-    # ---- autodecoder settings
-    sequence_autodecoder: Autodecoder
+    # ---- global encoder settings
+    global_encoder_class_type: Optional[str] = None
+    global_encoder: Optional[GlobalEncoderBase]

     # ---- raysampler
     raysampler_class_type: str = "AdaptiveRaySampler"
@@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             "loss_prev_stage_rgb_psnr_fg",
             "loss_prev_stage_rgb_psnr",
             "loss_prev_stage_mask_bce",
-            *STD_LOG_VARS,
+            # basic metrics
+            "objective",
+            "epoch",
+            "sec/it",
         ]
     )

@@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         *,  # force keyword-only arguments
         image_rgb: Optional[torch.Tensor],
         camera: CamerasBase,
-        fg_probability: Optional[torch.Tensor],
-        mask_crop: Optional[torch.Tensor],
-        depth_map: Optional[torch.Tensor],
-        sequence_name: Optional[List[str]],
+        fg_probability: Optional[torch.Tensor] = None,
+        mask_crop: Optional[torch.Tensor] = None,
+        depth_map: Optional[torch.Tensor] = None,
+        sequence_name: Optional[List[str]] = None,
+        frame_timestamp: Optional[torch.Tensor] = None,
         evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             sequence_name: A list of `B` strings corresponding to the sequence names
                 from which images `image_rgb` were extracted. They are used to match
                 target frames with relevant source frames.
+            frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
+                of frame timestamps.
             evaluation_mode: one of EvaluationMode.TRAINING or
                 EvaluationMode.EVALUATION which determines the settings used for
                 rendering.
@@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             else min(self.n_train_target_views, batch_size)
         )

+        # A helper function for selecting n_target first elements from the input
+        # where the latter can be None.
+        def _safe_slice_targets(
+            tensor: Optional[Union[torch.Tensor, List[str]]],
+        ) -> Optional[Union[torch.Tensor, List[str]]]:
+            return None if tensor is None else tensor[:n_targets]
+
         # Select the target cameras.
         target_cameras = camera[list(range(n_targets))]

@@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             custom_args["fun_viewpool"] = curried_viewpooler

         global_code = None
-        if self.sequence_autodecoder.n_instances > 0:
-            if sequence_name is None:
-                raise ValueError("sequence_name must be provided for autodecoder.")
-            global_code = self.sequence_autodecoder(sequence_name[:n_targets])
+        if self.global_encoder is not None:
+            global_code = self.global_encoder(  # pyre-fixme[29]
+                sequence_name=_safe_slice_targets(sequence_name),
+                frame_timestamp=_safe_slice_targets(frame_timestamp),
+            )
         custom_args["global_code"] = global_code

         # pyre-fixme[29]:
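With this change, `global_code` is produced by whichever encoder is configured; the model passes both the sliced sequence names and frame timestamps, and each implementation consumes only the keyword it needs. For `SequenceAutodecoder` the code is simply the wrapped `Autodecoder` applied to the sequence names. A hedged standalone sketch, with the module path as implied by the new `global_encoder` package in this diff and `expand_args_fields` assumed:

from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder
from pytorch3d.implicitron.tools.config import expand_args_fields

# Illustrative only: one learnable code per distinct sequence name.
expand_args_fields(Autodecoder)
autodecoder = Autodecoder(n_instances=100, encoding_dim=8)

global_code = autodecoder(["seq_a", "seq_b", "seq_a"])
print(global_code.shape)  # torch.Size([3, 8]); repeated names reuse the same code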
@@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         # A dict to store losses as well as rendering results.
         preds: Dict[str, Any] = {}

-        def safe_slice_targets(
-            tensor: Optional[torch.Tensor],
-        ) -> Optional[torch.Tensor]:
-            return None if tensor is None else tensor[:n_targets]
-
         preds.update(
             self.view_metrics(
                 results=preds,
                 raymarched=rendered,
                 xys=ray_bundle.xys,
-                image_rgb=safe_slice_targets(image_rgb),
-                depth_map=safe_slice_targets(depth_map),
-                fg_probability=safe_slice_targets(fg_probability),
-                mask_crop=safe_slice_targets(mask_crop),
+                image_rgb=_safe_slice_targets(image_rgb),
+                depth_map=_safe_slice_targets(depth_map),
+                fg_probability=_safe_slice_targets(fg_probability),
+                mask_crop=_safe_slice_targets(mask_crop),
             )
         )

@@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
                 **kwargs,
             )

+    def _get_global_encoder_encoding_dim(self) -> int:
+        if self.global_encoder is None:
+            return 0
+        return self.global_encoder.get_encoding_dim()
+
     def _get_viewpooled_feature_dim(self) -> int:
         if self.view_pooler is None:
             return 0
@@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
         nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
         nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
-            self._get_viewpooled_feature_dim()
-            + self.sequence_autodecoder.get_encoding_dim()
+            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
         )
         nerf_args["color_dim"] = nerformer_args[
             "color_dim"
@@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         # idr preprocessing
         idr = self.implicit_function_IdrFeatureField_args
         idr["feature_vector_size"] = self.render_features_dimensions
-        idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
+        idr["encoding_dim"] = self._get_global_encoder_encoding_dim()

         # srn preprocessing
         srn = self.implicit_function_SRNImplicitFunction_args
         srn.raymarch_function_args.latent_dim = (
-            self._get_viewpooled_feature_dim()
-            + self.sequence_autodecoder.get_encoding_dim()
+            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
         )

         # srn_hypernet preprocessing
         srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
         srn_hypernet_args = srn_hypernet.hypernet_args
-        srn_hypernet_args.latent_dim_hypernet = (
-            self.sequence_autodecoder.get_encoding_dim()
-        )
+        srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
         srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()

         # check that for srn, srn_hypernet, idr we have self.num_passes=1
							
								
								
									
pytorch3d/implicitron/models/global_encoder/__init__.py (new file, +5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
@@ -12,10 +12,9 @@ import torch
 from pytorch3d.implicitron.tools.config import Configurable


-# TODO: probabilistic embeddings?
 class Autodecoder(Configurable, torch.nn.Module):
     """
-    Autodecoder module
+    Autodecoder which maps a list of integer or string keys to optimizable embeddings.

     Settings:
         encoding_dim: Embedding dimension for the decoder.
@@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module):
             # weight has been initialised from Normal(0, 1)
             self._autodecoder_codes.weight *= self.init_scale

-        self._sequence_map = self._build_sequence_map()
+        self._key_map = self._build_key_map()
         # Make sure to register hooks for correct handling of saving/loading
-        # the module's _sequence_map.
-        self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
-        self._register_state_dict_hook(_save_sequence_map_hook)
+        # the module's _key_map.
+        self._register_load_state_dict_pre_hook(self._load_key_map_hook)
+        self._register_state_dict_hook(_save_key_map_hook)

-    def _build_sequence_map(
-        self, sequence_map_dict: Optional[Dict[str, int]] = None
+    def _build_key_map(
+        self, key_map_dict: Optional[Dict[str, int]] = None
     ) -> Dict[str, int]:
         """
         Args:
-            sequence_map_dict: A dictionary used to initialize the sequence_map.
+            key_map_dict: A dictionary used to initialize the key_map.

         Returns:
-            sequence_map: a dictionary of key: id pairs.
+            key_map: a dictionary of key: id pairs.
         """
         # increments the counter when asked for a new value
-        sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
-        if sequence_map_dict is not None:
-            # Assign all keys from the loaded sequence_map_dict to self._sequence_map.
+        key_map = defaultdict(iter(range(self.n_instances)).__next__)
+        if key_map_dict is not None:
+            # Assign all keys from the loaded key_map_dict to self._key_map.
             # Since this is done in the original order, it should generate
             # the same set of key:id pairs. We check this with an assert to be sure.
-            for x, x_id in sequence_map_dict.items():
-                x_id_ = sequence_map[x]
+            for x, x_id in key_map_dict.items():
+                x_id_ = key_map[x]
                 assert x_id == x_id_
-        return sequence_map
+        return key_map

     def calc_squared_encoding_norm(self):
         if self.n_instances <= 0:
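The renamed `_build_key_map` keeps the same id-assignment trick as before: a `defaultdict` whose default factory is the `__next__` of a range iterator, so each previously unseen key receives the next free integer id. A minimal standalone illustration:

from collections import defaultdict

n_instances = 4
key_map = defaultdict(iter(range(n_instances)).__next__)

print(key_map["seq_a"])  # 0: first unseen key gets the next free id
print(key_map["seq_b"])  # 1
print(key_map["seq_a"])  # 0: repeated keys keep their id
# Requesting more than n_instances distinct keys raises StopIteration.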
@@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
     def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
         """
         Args:
-            x: A batch of `N` sequence identifiers. Either a long tensor of size
+            x: A batch of `N` identifiers. Either a long tensor of size
             `(N,)` keys in [0, n_instances), or a list of `N` string keys that
             are hashed to codes (without collisions).

         Returns:
             codes: A tensor of shape `(N, self.encoding_dim)` containing the
-                sequence-specific autodecoder codes.
+                key-specific autodecoder codes.
         """
         if self.n_instances == 0:
             return None
@@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module):
                 #  `Tensor`.
                 x = torch.tensor(
                     # pyre-ignore[29]
-                    [self._sequence_map[elem] for elem in x],
+                    [self._key_map[elem] for elem in x],
                     dtype=torch.long,
                     device=next(self.parameters()).device,
                 )
@@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module):
         # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
         return self._autodecoder_codes(x)

-    def _load_sequence_map_hook(
+    def _load_key_map_hook(
         self,
         state_dict,
         prefix,
@@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
                 :meth:`~torch.nn.Module.load_state_dict`

         Returns:
-            Constructed sequence_map if it exists in the state_dict
+            Constructed key_map if it exists in the state_dict
             else raises a warning only.
         """
-        sequence_map_key = prefix + "_sequence_map"
-        if sequence_map_key in state_dict:
-            sequence_map_dict = state_dict.pop(sequence_map_key)
-            self._sequence_map = self._build_sequence_map(
-                sequence_map_dict=sequence_map_dict
-            )
+        key_map_key = prefix + "_key_map"
+        if key_map_key in state_dict:
+            key_map_dict = state_dict.pop(key_map_key)
+            self._key_map = self._build_key_map(key_map_dict=key_map_dict)
         else:
-            warnings.warn("No sequence map in Autodecoder state dict!")
+            warnings.warn("No key map in Autodecoder state dict!")


-def _save_sequence_map_hook(
+def _save_key_map_hook(
     self,
     state_dict,
     prefix,
@@ -169,6 +166,6 @@ def _save_sequence_map_hook(
             module
         local_metadata (dict): a dict containing the metadata for this module.
     """
-    sequence_map_key = prefix + "_sequence_map"
-    sequence_map_dict = dict(self._sequence_map.items())
-    state_dict[sequence_map_key] = sequence_map_dict
+    key_map_key = prefix + "_key_map"
+    key_map_dict = dict(self._key_map.items())
+    state_dict[key_map_key] = key_map_dict
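Because of the load and save hooks renamed above, the key map travels with the module's state dict, so string-keyed codes should survive a save/load round trip. A hedged sketch of that behaviour, with the import path and `expand_args_fields` assumed and the field values purely illustrative:

import torch

from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(Autodecoder)
src = Autodecoder(n_instances=10, encoding_dim=4)
code_before = src(["seq_a"]).detach().clone()

# The state dict carries the "_key_map" entry written by _save_key_map_hook,
# which _load_key_map_hook pops and rebuilds before the weights are copied.
dst = Autodecoder(n_instances=10, encoding_dim=4)
dst.load_state_dict(src.state_dict())
assert torch.allclose(dst(["seq_a"]), code_before)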
							
								
								
									
pytorch3d/implicitron/models/global_encoder/global_encoder.py (new file, +110 lines)
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Union
+
+import torch
+from pytorch3d.implicitron.tools.config import (
+    registry,
+    ReplaceableBase,
+    run_auto_creation,
+)
+from pytorch3d.renderer.implicit import HarmonicEmbedding
+
+from .autodecoder import Autodecoder
+
+
+class GlobalEncoderBase(ReplaceableBase):
+    """
+    A base class for implementing encoders of global frame-specific quantities.
+
+    The latter includes e.g. the harmonic encoding of a frame timestamp
+    (`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
+    (`SequenceAutodecoder`).
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_encoding_dim(self):
+        """
+        Returns the dimensionality of the returned encoding.
+        """
+        raise NotImplementedError()
+
+    def calc_squared_encoding_norm(self):
+        """
+        Calculates the squared norm of the encoding.
+        """
+        raise NotImplementedError()
+
+    def forward(self, **kwargs) -> torch.Tensor:
+        """
+        Given a set of inputs to encode, generates a tensor containing the encoding.
+
+        Returns:
+            encoding: The tensor containing the global encoding.
+        """
+        raise NotImplementedError()
+
+
+# TODO: probabilistic embeddings?
+@registry.register
+class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 13
+    """
+    A global encoder implementation which provides an autodecoder encoding
+    of the frame's sequence identifier.
+    """
+
+    autodecoder: Autodecoder
+
+    def __post_init__(self):
+        super().__init__()
+        run_auto_creation(self)
+
+    def get_encoding_dim(self):
+        return self.autodecoder.get_encoding_dim()
+
+    def forward(
+        self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
+    ) -> torch.Tensor:
+
+        # run dtype checks and pass sequence_name to self.autodecoder
+        return self.autodecoder(sequence_name)
+
+    def calc_squared_encoding_norm(self):
+        return self.autodecoder.calc_squared_encoding_norm()
+
+
+@registry.register
+class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
+    """
+    A global encoder implementation which provides harmonic embeddings
+    of each frame's timestamp.
+    """
+
+    n_harmonic_functions: int = 10
+    append_input: bool = True
+    time_divisor: float = 1.0
+
+    def __post_init__(self):
+        super().__init__()
+        self._harmonic_embedding = HarmonicEmbedding(
+            n_harmonic_functions=self.n_harmonic_functions,
+            append_input=self.append_input,
+        )
+
+    def get_encoding_dim(self):
+        return self._harmonic_embedding.get_output_dim(1)
+
+    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
+        if frame_timestamp.shape[-1] != 1:
+            raise ValueError("Frame timestamp's last dimensions should be one.")
+        time = frame_timestamp / self.time_divisor
+        return self._harmonic_embedding(time)  # pyre-ignore: 29
+
+    def calc_squared_encoding_norm(self):
+        return 0.0
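Since `GlobalEncoderBase` is a `ReplaceableBase`, further encoders can be plugged in without touching `GenericModel`: registering a subclass makes its name selectable via `global_encoder_class_type`. A hypothetical example, not part of this commit, modelled on `HarmonicTimeEncoder` above:

import torch

from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
from pytorch3d.implicitron.tools.config import registry


@registry.register
class ConstantEncoder(GlobalEncoderBase, torch.nn.Module):
    """Returns one learnable code shared by every frame (illustration only)."""

    encoding_dim: int = 16

    def __post_init__(self):
        super().__init__()
        self._code = torch.nn.Parameter(torch.zeros(1, self.encoding_dim))

    def get_encoding_dim(self):
        return self.encoding_dim

    def calc_squared_encoding_norm(self):
        return (self._code ** 2).mean()

    def forward(self, **kwargs) -> torch.Tensor:
        # Ignores sequence_name / frame_timestamp and returns the shared code.
        return self._code

Selecting it would then follow the same naming pattern as the YAML above: `global_encoder_class_type: ConstantEncoder`, with its fields exposed under `global_encoder_ConstantEncoder_args`.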
@@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
 n_train_target_views: 1
 sampling_mode_training: mask_sample
 sampling_mode_evaluation: full_grid
+global_encoder_class_type: SequenceAutodecoder
 raysampler_class_type: AdaptiveRaySampler
 renderer_class_type: LSTMRenderer
 image_feature_extractor_class_type: ResNetFeatureExtractor
@@ -48,11 +49,12 @@ log_vars:
 - objective
 - epoch
 - sec/it
-sequence_autodecoder_args:
-  encoding_dim: 0
-  n_instances: 0
-  init_scale: 1.0
-  ignore_input: false
+global_encoder_SequenceAutodecoder_args:
+  autodecoder_args:
+    encoding_dim: 0
+    n_instances: 0
+    init_scale: 1.0
+    ignore_input: false
 raysampler_AdaptiveRaySampler_args:
   image_width: 400
   image_height: 400
@@ -7,11 +7,13 @@
 import unittest

 from omegaconf import OmegaConf
-from pytorch3d.implicitron.models.autodecoder import Autodecoder
 from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
     ResNetFeatureExtractor,
 )
 from pytorch3d.implicitron.models.generic_model import GenericModel
+from pytorch3d.implicitron.models.global_encoder.global_encoder import (
+    SequenceAutodecoder,
+)
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )
@@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
         self.assertIsInstance(
             gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
         )
-        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
+        self.assertIsNone(gm.global_encoder)
         self.assertFalse(hasattr(gm, "implicit_function"))
         self.assertIsNone(gm.view_pooler)
         self.assertIsNone(gm.image_feature_extractor)
@@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
         )
         args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
         args.implicit_function_class_type = "IdrFeatureField"
+        args.global_encoder_class_type = "SequenceAutodecoder"
         idr_args = args.implicit_function_IdrFeatureField_args
         idr_args.n_harmonic_functions_xyz = 1729

@@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
         )
         self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
         self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
-        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
+        self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
         self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
         self.assertFalse(hasattr(gm, "implicit_function"))

@@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
+            print(DATA_DIR)
             (DATA_DIR / "overrides.yaml_").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())