Refactor autodecoders

Summary: Refactors autodecoders. Tests pass.

Reviewed By: bottler

Differential Revision: D37592429

fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
David Novotny 2022-07-04 07:18:03 -07:00 committed by Facebook GitHub Bot
parent ae35824f21
commit 0dce883241
10 changed files with 230 additions and 87 deletions

View File

@ -11,10 +11,12 @@ generic_model_args:
num_passes: 1 num_passes: 1
output_rasterized_mc: true output_rasterized_mc: true
sampling_mode_training: mask_sample sampling_mode_training: mask_sample
sequence_autodecoder_args: global_encoder_class_type: SequenceAutodecoder
n_instances: 20000 global_encoder_SequenceAutodecoder_args:
init_scale: 1.0 autodecoder_args:
encoding_dim: 256 n_instances: 20000
init_scale: 1.0
encoding_dim: 256
implicit_function_IdrFeatureField_args: implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6 n_harmonic_functions_xyz: 6
bias: 0.6 bias: 0.6

View File

@ -4,6 +4,8 @@ defaults:
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pooler_enabled: false view_pooler_enabled: false
sequence_autodecoder_args: global_encoder_class_type: SequenceAutodecoder
n_instances: 20000 global_encoder_SequenceAutodecoder_args:
encoding_dim: 256 autodecoder_args:
n_instances: 20000
encoding_dim: 256

View File

@ -13,9 +13,11 @@ generic_model_args:
loss_prev_stage_mask_bce: 0.0 loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.001 loss_autodecoder_norm: 0.001
depth_neg_penalty: 10000.0 depth_neg_penalty: 10000.0
sequence_autodecoder_args: global_encoder_class_type: SequenceAutodecoder
encoding_dim: 256 global_encoder_SequenceAutodecoder_args:
n_instances: 20000 autodecoder_args:
encoding_dim: 256
n_instances: 20000
raysampler_class_type: NearFarRaySampler raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args: raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048 n_rays_per_image_sampled_from_mask: 2048

View File

@ -16,6 +16,7 @@ generic_model_args:
n_train_target_views: 1 n_train_target_views: 1
sampling_mode_training: mask_sample sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid sampling_mode_evaluation: full_grid
global_encoder_class_type: null
raysampler_class_type: AdaptiveRaySampler raysampler_class_type: AdaptiveRaySampler
renderer_class_type: MultiPassEmissionAbsorptionRenderer renderer_class_type: MultiPassEmissionAbsorptionRenderer
image_feature_extractor_class_type: null image_feature_extractor_class_type: null
@ -49,11 +50,16 @@ generic_model_args:
- objective - objective
- epoch - epoch
- sec/it - sec/it
sequence_autodecoder_args: global_encoder_HarmonicTimeEncoder_args:
encoding_dim: 0 n_harmonic_functions: 10
n_instances: 0 append_input: true
init_scale: 1.0 time_divisor: 1.0
ignore_input: false global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
image_width: 400 image_width: 400
image_height: 400 image_height: 400

View File

@ -12,7 +12,7 @@ import logging
import math import math
import warnings import warnings
from dataclasses import field from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import tqdm import tqdm
@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom from visdom import Visdom
from .autodecoder import Autodecoder
from .base_model import ImplicitronModelBase, ImplicitronRender from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase from .feature_extractor import FeatureExtractorBase
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
from .global_encoder.global_encoder import GlobalEncoderBase
from .implicit_function.base import ImplicitFunctionBase from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa from .implicit_function.neural_radiance_field import ( # noqa
@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .view_pooler.view_pooler import ViewPooler from .view_pooler.view_pooler import ViewPooler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
------------------ ------------------
Evaluate the implicit function(s) at the sampled ray points Evaluate the implicit function(s) at the sampled ray points
(optionally pass in the aggregated image features from (4)). (optionally pass in the aggregated image features from (4)).
(also optionally pass in a global encoding from global_encoder).
(6) Rendering (6) Rendering
@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sampling_mode_training: The sampling method to use during training. Must be sampling_mode_training: The sampling method to use during training. Must be
a value from the RenderSamplingMode Enum. a value from the RenderSamplingMode Enum.
sampling_mode_evaluation: Same as above but for evaluation. sampling_mode_evaluation: Same as above but for evaluation.
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding global_encoder_class_type: The name of the class to use for global_encoder,
which must be available in the registry. Or `None` to disable global encoder.
global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
of the image (referred to as the global_code) that can be used to model aspects of of the image (referred to as the global_code) that can be used to model aspects of
the scene such as multiple objects or morphing objects. It is up to the implicit the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and function definition how to use it, but the most typical way is to broadcast and
@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sampling_mode_training: str = "mask_sample" sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid" sampling_mode_evaluation: str = "full_grid"
# ---- autodecoder settings # ---- global encoder settings
sequence_autodecoder: Autodecoder global_encoder_class_type: Optional[str] = None
global_encoder: Optional[GlobalEncoderBase]
# ---- raysampler # ---- raysampler
raysampler_class_type: str = "AdaptiveRaySampler" raysampler_class_type: str = "AdaptiveRaySampler"
@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"loss_prev_stage_rgb_psnr_fg", "loss_prev_stage_rgb_psnr_fg",
"loss_prev_stage_rgb_psnr", "loss_prev_stage_rgb_psnr",
"loss_prev_stage_mask_bce", "loss_prev_stage_mask_bce",
*STD_LOG_VARS, # basic metrics
"objective",
"epoch",
"sec/it",
] ]
) )
@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
*, # force keyword-only arguments *, # force keyword-only arguments
image_rgb: Optional[torch.Tensor], image_rgb: Optional[torch.Tensor],
camera: CamerasBase, camera: CamerasBase,
fg_probability: Optional[torch.Tensor], fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor], mask_crop: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor], depth_map: Optional[torch.Tensor] = None,
sequence_name: Optional[List[str]], sequence_name: Optional[List[str]] = None,
frame_timestamp: Optional[torch.Tensor] = None,
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sequence_name: A list of `B` strings corresponding to the sequence names sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames. target frames with relevant source frames.
frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
of frame timestamps.
evaluation_mode: one of EvaluationMode.TRAINING or evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for EvaluationMode.EVALUATION which determines the settings used for
rendering. rendering.
@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
else min(self.n_train_target_views, batch_size) else min(self.n_train_target_views, batch_size)
) )
# A helper function for selecting n_target first elements from the input
# where the latter can be None.
def _safe_slice_targets(
tensor: Optional[Union[torch.Tensor, List[str]]],
) -> Optional[Union[torch.Tensor, List[str]]]:
return None if tensor is None else tensor[:n_targets]
# Select the target cameras. # Select the target cameras.
target_cameras = camera[list(range(n_targets))] target_cameras = camera[list(range(n_targets))]
@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
custom_args["fun_viewpool"] = curried_viewpooler custom_args["fun_viewpool"] = curried_viewpooler
global_code = None global_code = None
if self.sequence_autodecoder.n_instances > 0: if self.global_encoder is not None:
if sequence_name is None: global_code = self.global_encoder( # pyre-fixme[29]
raise ValueError("sequence_name must be provided for autodecoder.") sequence_name=_safe_slice_targets(sequence_name),
global_code = self.sequence_autodecoder(sequence_name[:n_targets]) frame_timestamp=_safe_slice_targets(frame_timestamp),
)
custom_args["global_code"] = global_code custom_args["global_code"] = global_code
# pyre-fixme[29]: # pyre-fixme[29]:
@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# A dict to store losses as well as rendering results. # A dict to store losses as well as rendering results.
preds: Dict[str, Any] = {} preds: Dict[str, Any] = {}
def safe_slice_targets(
tensor: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
return None if tensor is None else tensor[:n_targets]
preds.update( preds.update(
self.view_metrics( self.view_metrics(
results=preds, results=preds,
raymarched=rendered, raymarched=rendered,
xys=ray_bundle.xys, xys=ray_bundle.xys,
image_rgb=safe_slice_targets(image_rgb), image_rgb=_safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map), depth_map=_safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability), fg_probability=_safe_slice_targets(fg_probability),
mask_crop=safe_slice_targets(mask_crop), mask_crop=_safe_slice_targets(mask_crop),
) )
) )
@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
**kwargs, **kwargs,
) )
def _get_global_encoder_encoding_dim(self) -> int:
if self.global_encoder is None:
return 0
return self.global_encoder.get_encoding_dim()
def _get_viewpooled_feature_dim(self) -> int: def _get_viewpooled_feature_dim(self) -> int:
if self.view_pooler is None: if self.view_pooler is None:
return 0 return 0
@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = ( nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
self._get_viewpooled_feature_dim() self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
+ self.sequence_autodecoder.get_encoding_dim()
) )
nerf_args["color_dim"] = nerformer_args[ nerf_args["color_dim"] = nerformer_args[
"color_dim" "color_dim"
@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# idr preprocessing # idr preprocessing
idr = self.implicit_function_IdrFeatureField_args idr = self.implicit_function_IdrFeatureField_args
idr["feature_vector_size"] = self.render_features_dimensions idr["feature_vector_size"] = self.render_features_dimensions
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim() idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
# srn preprocessing # srn preprocessing
srn = self.implicit_function_SRNImplicitFunction_args srn = self.implicit_function_SRNImplicitFunction_args
srn.raymarch_function_args.latent_dim = ( srn.raymarch_function_args.latent_dim = (
self._get_viewpooled_feature_dim() self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
+ self.sequence_autodecoder.get_encoding_dim()
) )
# srn_hypernet preprocessing # srn_hypernet preprocessing
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
srn_hypernet_args = srn_hypernet.hypernet_args srn_hypernet_args = srn_hypernet.hypernet_args
srn_hypernet_args.latent_dim_hypernet = ( srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
self.sequence_autodecoder.get_encoding_dim()
)
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim() srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
# check that for srn, srn_hypernet, idr we have self.num_passes=1 # check that for srn, srn_hypernet, idr we have self.num_passes=1

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -12,10 +12,9 @@ import torch
from pytorch3d.implicitron.tools.config import Configurable from pytorch3d.implicitron.tools.config import Configurable
# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module): class Autodecoder(Configurable, torch.nn.Module):
""" """
Autodecoder module Autodecoder which maps a list of integer or string keys to optimizable embeddings.
Settings: Settings:
encoding_dim: Embedding dimension for the decoder. encoding_dim: Embedding dimension for the decoder.
@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module):
# weight has been initialised from Normal(0, 1) # weight has been initialised from Normal(0, 1)
self._autodecoder_codes.weight *= self.init_scale self._autodecoder_codes.weight *= self.init_scale
self._sequence_map = self._build_sequence_map() self._key_map = self._build_key_map()
# Make sure to register hooks for correct handling of saving/loading # Make sure to register hooks for correct handling of saving/loading
# the module's _sequence_map. # the module's _key_map.
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook) self._register_load_state_dict_pre_hook(self._load_key_map_hook)
self._register_state_dict_hook(_save_sequence_map_hook) self._register_state_dict_hook(_save_key_map_hook)
def _build_sequence_map( def _build_key_map(
self, sequence_map_dict: Optional[Dict[str, int]] = None self, key_map_dict: Optional[Dict[str, int]] = None
) -> Dict[str, int]: ) -> Dict[str, int]:
""" """
Args: Args:
sequence_map_dict: A dictionary used to initialize the sequence_map. key_map_dict: A dictionary used to initialize the key_map.
Returns: Returns:
sequence_map: a dictionary of key: id pairs. key_map: a dictionary of key: id pairs.
""" """
# increments the counter when asked for a new value # increments the counter when asked for a new value
sequence_map = defaultdict(iter(range(self.n_instances)).__next__) key_map = defaultdict(iter(range(self.n_instances)).__next__)
if sequence_map_dict is not None: if key_map_dict is not None:
# Assign all keys from the loaded sequence_map_dict to self._sequence_map. # Assign all keys from the loaded key_map_dict to self._key_map.
# Since this is done in the original order, it should generate # Since this is done in the original order, it should generate
# the same set of key:id pairs. We check this with an assert to be sure. # the same set of key:id pairs. We check this with an assert to be sure.
for x, x_id in sequence_map_dict.items(): for x, x_id in key_map_dict.items():
x_id_ = sequence_map[x] x_id_ = key_map[x]
assert x_id == x_id_ assert x_id == x_id_
return sequence_map return key_map
def calc_squared_encoding_norm(self): def calc_squared_encoding_norm(self):
if self.n_instances <= 0: if self.n_instances <= 0:
@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]: def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
""" """
Args: Args:
x: A batch of `N` sequence identifiers. Either a long tensor of size x: A batch of `N` identifiers. Either a long tensor of size
`(N,)` keys in [0, n_instances), or a list of `N` string keys that `(N,)` keys in [0, n_instances), or a list of `N` string keys that
are hashed to codes (without collisions). are hashed to codes (without collisions).
Returns: Returns:
codes: A tensor of shape `(N, self.encoding_dim)` containing the codes: A tensor of shape `(N, self.encoding_dim)` containing the
sequence-specific autodecoder codes. key-specific autodecoder codes.
""" """
if self.n_instances == 0: if self.n_instances == 0:
return None return None
@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module):
# `Tensor`. # `Tensor`.
x = torch.tensor( x = torch.tensor(
# pyre-ignore[29] # pyre-ignore[29]
[self._sequence_map[elem] for elem in x], [self._key_map[elem] for elem in x],
dtype=torch.long, dtype=torch.long,
device=next(self.parameters()).device, device=next(self.parameters()).device,
) )
@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._autodecoder_codes(x) return self._autodecoder_codes(x)
def _load_sequence_map_hook( def _load_key_map_hook(
self, self,
state_dict, state_dict,
prefix, prefix,
@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
:meth:`~torch.nn.Module.load_state_dict` :meth:`~torch.nn.Module.load_state_dict`
Returns: Returns:
Constructed sequence_map if it exists in the state_dict Constructed key_map if it exists in the state_dict
else raises a warning only. else raises a warning only.
""" """
sequence_map_key = prefix + "_sequence_map" key_map_key = prefix + "_key_map"
if sequence_map_key in state_dict: if key_map_key in state_dict:
sequence_map_dict = state_dict.pop(sequence_map_key) key_map_dict = state_dict.pop(key_map_key)
self._sequence_map = self._build_sequence_map( self._key_map = self._build_key_map(key_map_dict=key_map_dict)
sequence_map_dict=sequence_map_dict
)
else: else:
warnings.warn("No sequence map in Autodecoder state dict!") warnings.warn("No key map in Autodecoder state dict!")
def _save_sequence_map_hook( def _save_key_map_hook(
self, self,
state_dict, state_dict,
prefix, prefix,
@ -169,6 +166,6 @@ def _save_sequence_map_hook(
module module
local_metadata (dict): a dict containing the metadata for this module. local_metadata (dict): a dict containing the metadata for this module.
""" """
sequence_map_key = prefix + "_sequence_map" key_map_key = prefix + "_key_map"
sequence_map_dict = dict(self._sequence_map.items()) key_map_dict = dict(self._key_map.items())
state_dict[sequence_map_key] = sequence_map_dict state_dict[key_map_key] = key_map_dict

View File

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Union
import torch
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .autodecoder import Autodecoder
class GlobalEncoderBase(ReplaceableBase):
"""
A base class for implementing encoders of global frame-specific quantities.
The latter includes e.g. the harmonic encoding of a frame timestamp
(`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
(`SequenceAutodecoder`).
"""
def __init__(self) -> None:
super().__init__()
def get_encoding_dim(self):
"""
Returns the dimensionality of the returned encoding.
"""
raise NotImplementedError()
def calc_squared_encoding_norm(self):
"""
Calculates the squared norm of the encoding.
"""
raise NotImplementedError()
def forward(self, **kwargs) -> torch.Tensor:
"""
Given a set of inputs to encode, generates a tensor containing the encoding.
Returns:
encoding: The tensor containing the global encoding.
"""
raise NotImplementedError()
# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
"""
A global encoder implementation which provides an autodecoder encoding
of the frame's sequence identifier.
"""
autodecoder: Autodecoder
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def get_encoding_dim(self):
return self.autodecoder.get_encoding_dim()
def forward(
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
) -> torch.Tensor:
# run dtype checks and pass sequence_name to self.autodecoder
return self.autodecoder(sequence_name)
def calc_squared_encoding_norm(self):
return self.autodecoder.calc_squared_encoding_norm()
@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
"""
A global encoder implementation which provides harmonic embeddings
of each frame's timestamp.
"""
n_harmonic_functions: int = 10
append_input: bool = True
time_divisor: float = 1.0
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
n_harmonic_functions=self.n_harmonic_functions,
append_input=self.append_input,
)
def get_encoding_dim(self):
return self._harmonic_embedding.get_output_dim(1)
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
if frame_timestamp.shape[-1] != 1:
raise ValueError("Frame timestamp's last dimensions should be one.")
time = frame_timestamp / self.time_divisor
return self._harmonic_embedding(time) # pyre-ignore: 29
def calc_squared_encoding_norm(self):
return 0.0

View File

@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
n_train_target_views: 1 n_train_target_views: 1
sampling_mode_training: mask_sample sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid sampling_mode_evaluation: full_grid
global_encoder_class_type: SequenceAutodecoder
raysampler_class_type: AdaptiveRaySampler raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer renderer_class_type: LSTMRenderer
image_feature_extractor_class_type: ResNetFeatureExtractor image_feature_extractor_class_type: ResNetFeatureExtractor
@ -48,11 +49,12 @@ log_vars:
- objective - objective
- epoch - epoch
- sec/it - sec/it
sequence_autodecoder_args: global_encoder_SequenceAutodecoder_args:
encoding_dim: 0 autodecoder_args:
n_instances: 0 encoding_dim: 0
init_scale: 1.0 n_instances: 0
ignore_input: false init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
image_width: 400 image_width: 400
image_height: 400 image_height: 400

View File

@ -7,11 +7,13 @@
import unittest import unittest
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor, ResNetFeatureExtractor,
) )
from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
SequenceAutodecoder,
)
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField, IdrFeatureField,
) )
@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
self.assertIsInstance( self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
) )
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertIsNone(gm.global_encoder)
self.assertFalse(hasattr(gm, "implicit_function")) self.assertFalse(hasattr(gm, "implicit_function"))
self.assertIsNone(gm.view_pooler) self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor) self.assertIsNone(gm.image_feature_extractor)
@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
) )
args.image_feature_extractor_class_type = "ResNetFeatureExtractor" args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
args.implicit_function_class_type = "IdrFeatureField" args.implicit_function_class_type = "IdrFeatureField"
args.global_encoder_class_type = "SequenceAutodecoder"
idr_args = args.implicit_function_IdrFeatureField_args idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729 idr_args.n_harmonic_functions_xyz = 1729
@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
) )
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729) self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor) self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function")) self.assertFalse(hasattr(gm, "implicit_function"))
@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args) remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG: if DEBUG:
print(DATA_DIR)
(DATA_DIR / "overrides.yaml_").write_text(yaml) (DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text()) self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())