Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Refactor autodecoders
Summary: Refactors autodecoders. Tests pass.
Reviewed By: bottler
Differential Revision: D37592429
fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
parent ae35824f21
commit 0dce883241
@@ -11,10 +11,12 @@ generic_model_args:
   num_passes: 1
   output_rasterized_mc: true
   sampling_mode_training: mask_sample
-  sequence_autodecoder_args:
-    n_instances: 20000
-    init_scale: 1.0
-    encoding_dim: 256
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      n_instances: 20000
+      init_scale: 1.0
+      encoding_dim: 256
   implicit_function_IdrFeatureField_args:
     n_harmonic_functions_xyz: 6
     bias: 0.6
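Not part of the diff: a minimal sketch of how the relocated autodecoder settings are addressed after this change, assuming standard OmegaConf dotted access; the values are copied from the hunk above.

from omegaconf import OmegaConf

# Config fragment mirroring the new layout shown in the hunk above.
cfg = OmegaConf.create(
    {
        "generic_model_args": {
            "global_encoder_class_type": "SequenceAutodecoder",
            "global_encoder_SequenceAutodecoder_args": {
                "autodecoder_args": {
                    "n_instances": 20000,
                    "init_scale": 1.0,
                    "encoding_dim": 256,
                }
            },
        }
    }
)

# The autodecoder settings now live under the global encoder node instead of
# the removed top-level sequence_autodecoder_args.
gm_args = cfg.generic_model_args
print(gm_args.global_encoder_SequenceAutodecoder_args.autodecoder_args.encoding_dim)  # 256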
@@ -4,6 +4,8 @@ defaults:
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: false
-  sequence_autodecoder_args:
-    n_instances: 20000
-    encoding_dim: 256
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      n_instances: 20000
+      encoding_dim: 256
@@ -13,9 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.001
     depth_neg_penalty: 10000.0
-  sequence_autodecoder_args:
-    encoding_dim: 256
-    n_instances: 20000
+  global_encoder_class_type: SequenceAutodecoder
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      encoding_dim: 256
+      n_instances: 20000
   raysampler_class_type: NearFarRaySampler
   raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
@@ -16,6 +16,7 @@ generic_model_args:
   n_train_target_views: 1
   sampling_mode_training: mask_sample
   sampling_mode_evaluation: full_grid
+  global_encoder_class_type: null
   raysampler_class_type: AdaptiveRaySampler
   renderer_class_type: MultiPassEmissionAbsorptionRenderer
   image_feature_extractor_class_type: null
@@ -49,11 +50,16 @@ generic_model_args:
   - objective
   - epoch
   - sec/it
-  sequence_autodecoder_args:
-    encoding_dim: 0
-    n_instances: 0
-    init_scale: 1.0
-    ignore_input: false
+  global_encoder_HarmonicTimeEncoder_args:
+    n_harmonic_functions: 10
+    append_input: true
+    time_divisor: 1.0
+  global_encoder_SequenceAutodecoder_args:
+    autodecoder_args:
+      encoding_dim: 0
+      n_instances: 0
+      init_scale: 1.0
+      ignore_input: false
   raysampler_AdaptiveRaySampler_args:
     image_width: 400
     image_height: 400
@@ -12,7 +12,7 @@ import logging
 import math
 import warnings
 from dataclasses import field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import tqdm
@@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils
 from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom

-from .autodecoder import Autodecoder
 from .base_model import ImplicitronModelBase, ImplicitronRender
 from .feature_extractor import FeatureExtractorBase
 from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
+from .global_encoder.global_encoder import GlobalEncoderBase
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField # noqa
 from .implicit_function.neural_radiance_field import ( # noqa
@@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
 from .view_pooler.view_pooler import ViewPooler


-STD_LOG_VARS = ["objective", "epoch", "sec/it"]
 logger = logging.getLogger(__name__)


@@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
     ------------------
     Evaluate the implicit function(s) at the sampled ray points
     (optionally pass in the aggregated image features from (4)).
+    (also optionally pass in a global encoding from global_encoder).
             │
             ▼
     (6) Rendering
@@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         sampling_mode_training: The sampling method to use during training. Must be
             a value from the RenderSamplingMode Enum.
         sampling_mode_evaluation: Same as above but for evaluation.
-        sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
+        global_encoder_class_type: The name of the class to use for global_encoder,
+            which must be available in the registry. Or `None` to disable global encoder.
+        global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
             of the image (referred to as the global_code) that can be used to model aspects of
             the scene such as multiple objects or morphing objects. It is up to the implicit
             function definition how to use it, but the most typical way is to broadcast and
@@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
     sampling_mode_training: str = "mask_sample"
     sampling_mode_evaluation: str = "full_grid"

-    # ---- autodecoder settings
-    sequence_autodecoder: Autodecoder
+    # ---- global encoder settings
+    global_encoder_class_type: Optional[str] = None
+    global_encoder: Optional[GlobalEncoderBase]

     # ---- raysampler
     raysampler_class_type: str = "AdaptiveRaySampler"
@@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
             "loss_prev_stage_rgb_psnr_fg",
             "loss_prev_stage_rgb_psnr",
             "loss_prev_stage_mask_bce",
-            *STD_LOG_VARS,
+            # basic metrics
+            "objective",
+            "epoch",
+            "sec/it",
         ]
     )

@@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         *, # force keyword-only arguments
         image_rgb: Optional[torch.Tensor],
         camera: CamerasBase,
-        fg_probability: Optional[torch.Tensor],
-        mask_crop: Optional[torch.Tensor],
-        depth_map: Optional[torch.Tensor],
-        sequence_name: Optional[List[str]],
+        fg_probability: Optional[torch.Tensor] = None,
+        mask_crop: Optional[torch.Tensor] = None,
+        depth_map: Optional[torch.Tensor] = None,
+        sequence_name: Optional[List[str]] = None,
+        frame_timestamp: Optional[torch.Tensor] = None,
         evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
             sequence_name: A list of `B` strings corresponding to the sequence names
                 from which images `image_rgb` were extracted. They are used to match
                 target frames with relevant source frames.
+            frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
+                of frame timestamps.
             evaluation_mode: one of EvaluationMode.TRAINING or
                 EvaluationMode.EVALUATION which determines the settings used for
                 rendering.
@@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
             else min(self.n_train_target_views, batch_size)
         )

+        # A helper function for selecting n_target first elements from the input
+        # where the latter can be None.
+        def _safe_slice_targets(
+            tensor: Optional[Union[torch.Tensor, List[str]]],
+        ) -> Optional[Union[torch.Tensor, List[str]]]:
+            return None if tensor is None else tensor[:n_targets]
+
         # Select the target cameras.
         target_cameras = camera[list(range(n_targets))]

@@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
             custom_args["fun_viewpool"] = curried_viewpooler

         global_code = None
-        if self.sequence_autodecoder.n_instances > 0:
-            if sequence_name is None:
-                raise ValueError("sequence_name must be provided for autodecoder.")
-            global_code = self.sequence_autodecoder(sequence_name[:n_targets])
+        if self.global_encoder is not None:
+            global_code = self.global_encoder( # pyre-fixme[29]
+                sequence_name=_safe_slice_targets(sequence_name),
+                frame_timestamp=_safe_slice_targets(frame_timestamp),
+            )
         custom_args["global_code"] = global_code

         # pyre-fixme[29]:
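Not part of the diff: a minimal sketch of the calling convention the new branch relies on. Any configured global encoder is invoked with both `sequence_name` and `frame_timestamp` keyword arguments and absorbs the ones it does not use via `**kwargs`; the toy encoder below is illustrative only.

import torch

# Toy stand-in for a GlobalEncoderBase subclass; the name and behaviour are hypothetical.
class _ToyTimeEncoder(torch.nn.Module):
    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
        # kwargs absorbs sequence_name, which this particular encoder does not need.
        return frame_timestamp * 2.0

enc = _ToyTimeEncoder()
global_code = enc(
    sequence_name=["seq_a", "seq_b"],
    frame_timestamp=torch.tensor([[0.1], [0.2]]),
)  # -> tensor([[0.2], [0.4]])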
@@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         # A dict to store losses as well as rendering results.
         preds: Dict[str, Any] = {}

-        def safe_slice_targets(
-            tensor: Optional[torch.Tensor],
-        ) -> Optional[torch.Tensor]:
-            return None if tensor is None else tensor[:n_targets]
-
         preds.update(
             self.view_metrics(
                 results=preds,
                 raymarched=rendered,
                 xys=ray_bundle.xys,
-                image_rgb=safe_slice_targets(image_rgb),
-                depth_map=safe_slice_targets(depth_map),
-                fg_probability=safe_slice_targets(fg_probability),
-                mask_crop=safe_slice_targets(mask_crop),
+                image_rgb=_safe_slice_targets(image_rgb),
+                depth_map=_safe_slice_targets(depth_map),
+                fg_probability=_safe_slice_targets(fg_probability),
+                mask_crop=_safe_slice_targets(mask_crop),
             )
         )

@@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
             **kwargs,
         )

+    def _get_global_encoder_encoding_dim(self) -> int:
+        if self.global_encoder is None:
+            return 0
+        return self.global_encoder.get_encoding_dim()
+
     def _get_viewpooled_feature_dim(self) -> int:
         if self.view_pooler is None:
             return 0
@@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
         nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
         nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
-            self._get_viewpooled_feature_dim()
-            + self.sequence_autodecoder.get_encoding_dim()
+            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
         )
         nerf_args["color_dim"] = nerformer_args[
             "color_dim"
@@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         # idr preprocessing
         idr = self.implicit_function_IdrFeatureField_args
         idr["feature_vector_size"] = self.render_features_dimensions
-        idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
+        idr["encoding_dim"] = self._get_global_encoder_encoding_dim()

         # srn preprocessing
         srn = self.implicit_function_SRNImplicitFunction_args
         srn.raymarch_function_args.latent_dim = (
-            self._get_viewpooled_feature_dim()
-            + self.sequence_autodecoder.get_encoding_dim()
+            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
         )

         # srn_hypernet preprocessing
         srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
         srn_hypernet_args = srn_hypernet.hypernet_args
-        srn_hypernet_args.latent_dim_hypernet = (
-            self.sequence_autodecoder.get_encoding_dim()
-        )
+        srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
         srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()

         # check that for srn, srn_hypernet, idr we have self.num_passes=1
pytorch3d/implicitron/models/global_encoder/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
@@ -12,10 +12,9 @@ import torch
 from pytorch3d.implicitron.tools.config import Configurable


-# TODO: probabilistic embeddings?
 class Autodecoder(Configurable, torch.nn.Module):
     """
-    Autodecoder module
+    Autodecoder which maps a list of integer or string keys to optimizable embeddings.

     Settings:
         encoding_dim: Embedding dimension for the decoder.
@@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module):
         # weight has been initialised from Normal(0, 1)
         self._autodecoder_codes.weight *= self.init_scale

-        self._sequence_map = self._build_sequence_map()
+        self._key_map = self._build_key_map()
         # Make sure to register hooks for correct handling of saving/loading
-        # the module's _sequence_map.
-        self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
-        self._register_state_dict_hook(_save_sequence_map_hook)
+        # the module's _key_map.
+        self._register_load_state_dict_pre_hook(self._load_key_map_hook)
+        self._register_state_dict_hook(_save_key_map_hook)

-    def _build_sequence_map(
-        self, sequence_map_dict: Optional[Dict[str, int]] = None
+    def _build_key_map(
+        self, key_map_dict: Optional[Dict[str, int]] = None
     ) -> Dict[str, int]:
         """
         Args:
-            sequence_map_dict: A dictionary used to initialize the sequence_map.
+            key_map_dict: A dictionary used to initialize the key_map.

         Returns:
-            sequence_map: a dictionary of key: id pairs.
+            key_map: a dictionary of key: id pairs.
         """
         # increments the counter when asked for a new value
-        sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
-        if sequence_map_dict is not None:
-            # Assign all keys from the loaded sequence_map_dict to self._sequence_map.
+        key_map = defaultdict(iter(range(self.n_instances)).__next__)
+        if key_map_dict is not None:
+            # Assign all keys from the loaded key_map_dict to self._key_map.
             # Since this is done in the original order, it should generate
             # the same set of key:id pairs. We check this with an assert to be sure.
-            for x, x_id in sequence_map_dict.items():
-                x_id_ = sequence_map[x]
+            for x, x_id in key_map_dict.items():
+                x_id_ = key_map[x]
                 assert x_id == x_id_
-        return sequence_map
+        return key_map

     def calc_squared_encoding_norm(self):
         if self.n_instances <= 0:
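Not part of the diff: a standalone illustration of the counter-backed defaultdict used by `_build_key_map` above, which assigns a fresh id the first time a key is seen and reuses it afterwards (the capacity of 3 is hypothetical).

from collections import defaultdict

# Same construction as key_map above: the factory hands out 0, 1, 2, ... on first access.
key_map = defaultdict(iter(range(3)).__next__)
print(key_map["seq_a"], key_map["seq_b"], key_map["seq_a"])  # 0 1 0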
@@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
     def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
         """
         Args:
-            x: A batch of `N` sequence identifiers. Either a long tensor of size
+            x: A batch of `N` identifiers. Either a long tensor of size
                 `(N,)` keys in [0, n_instances), or a list of `N` string keys that
                 are hashed to codes (without collisions).

         Returns:
             codes: A tensor of shape `(N, self.encoding_dim)` containing the
-                sequence-specific autodecoder codes.
+                key-specific autodecoder codes.
         """
         if self.n_instances == 0:
             return None
@@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module):
             # `Tensor`.
             x = torch.tensor(
                 # pyre-ignore[29]
-                [self._sequence_map[elem] for elem in x],
+                [self._key_map[elem] for elem in x],
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
@@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module):
         # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
         return self._autodecoder_codes(x)

-    def _load_sequence_map_hook(
+    def _load_key_map_hook(
         self,
         state_dict,
         prefix,
@@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
             :meth:`~torch.nn.Module.load_state_dict`

         Returns:
-            Constructed sequence_map if it exists in the state_dict
+            Constructed key_map if it exists in the state_dict
             else raises a warning only.
         """
-        sequence_map_key = prefix + "_sequence_map"
-        if sequence_map_key in state_dict:
-            sequence_map_dict = state_dict.pop(sequence_map_key)
-            self._sequence_map = self._build_sequence_map(
-                sequence_map_dict=sequence_map_dict
-            )
+        key_map_key = prefix + "_key_map"
+        if key_map_key in state_dict:
+            key_map_dict = state_dict.pop(key_map_key)
+            self._key_map = self._build_key_map(key_map_dict=key_map_dict)
         else:
-            warnings.warn("No sequence map in Autodecoder state dict!")
+            warnings.warn("No key map in Autodecoder state dict!")


-def _save_sequence_map_hook(
+def _save_key_map_hook(
     self,
     state_dict,
     prefix,
@@ -169,6 +166,6 @@ def _save_sequence_map_hook(
         module
         local_metadata (dict): a dict containing the metadata for this module.
     """
-    sequence_map_key = prefix + "_sequence_map"
-    sequence_map_dict = dict(self._sequence_map.items())
-    state_dict[sequence_map_key] = sequence_map_dict
+    key_map_key = prefix + "_key_map"
+    key_map_dict = dict(self._key_map.items())
+    state_dict[key_map_key] = key_map_dict
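Not part of the diff: a minimal usage sketch of the renamed key-map machinery, assuming the Autodecoder module now lives at pytorch3d.implicitron.models.global_encoder.autodecoder (inferred from the `from .autodecoder import Autodecoder` import in the new file below) and that the class is constructed directly after expand_args_fields.

from pytorch3d.implicitron.tools.config import expand_args_fields
# Module path assumed from the relative import in global_encoder.py below.
from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder

expand_args_fields(Autodecoder)
ad = Autodecoder(n_instances=1000, encoding_dim=256, init_scale=1.0)

# String keys are mapped to integer ids via the renamed _key_map and looked up
# in the embedding table; identical keys get identical codes.
codes = ad(["seq_a", "seq_b", "seq_a"])
print(codes.shape)  # torch.Size([3, 256])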
pytorch3d/implicitron/models/global_encoder/global_encoder.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Union
+
+import torch
+from pytorch3d.implicitron.tools.config import (
+    registry,
+    ReplaceableBase,
+    run_auto_creation,
+)
+from pytorch3d.renderer.implicit import HarmonicEmbedding
+
+from .autodecoder import Autodecoder
+
+
+class GlobalEncoderBase(ReplaceableBase):
+    """
+    A base class for implementing encoders of global frame-specific quantities.
+
+    The latter includes e.g. the harmonic encoding of a frame timestamp
+    (`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
+    (`SequenceAutodecoder`).
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_encoding_dim(self):
+        """
+        Returns the dimensionality of the returned encoding.
+        """
+        raise NotImplementedError()
+
+    def calc_squared_encoding_norm(self):
+        """
+        Calculates the squared norm of the encoding.
+        """
+        raise NotImplementedError()
+
+    def forward(self, **kwargs) -> torch.Tensor:
+        """
+        Given a set of inputs to encode, generates a tensor containing the encoding.
+
+        Returns:
+            encoding: The tensor containing the global encoding.
+        """
+        raise NotImplementedError()
+
+
+# TODO: probabilistic embeddings?
+@registry.register
+class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
+    """
+    A global encoder implementation which provides an autodecoder encoding
+    of the frame's sequence identifier.
+    """
+
+    autodecoder: Autodecoder
+
+    def __post_init__(self):
+        super().__init__()
+        run_auto_creation(self)
+
+    def get_encoding_dim(self):
+        return self.autodecoder.get_encoding_dim()
+
+    def forward(
+        self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
+    ) -> torch.Tensor:
+
+        # run dtype checks and pass sequence_name to self.autodecoder
+        return self.autodecoder(sequence_name)
+
+    def calc_squared_encoding_norm(self):
+        return self.autodecoder.calc_squared_encoding_norm()
+
+
+@registry.register
+class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
+    """
+    A global encoder implementation which provides harmonic embeddings
+    of each frame's timestamp.
+    """
+
+    n_harmonic_functions: int = 10
+    append_input: bool = True
+    time_divisor: float = 1.0
+
+    def __post_init__(self):
+        super().__init__()
+        self._harmonic_embedding = HarmonicEmbedding(
+            n_harmonic_functions=self.n_harmonic_functions,
+            append_input=self.append_input,
+        )
+
+    def get_encoding_dim(self):
+        return self._harmonic_embedding.get_output_dim(1)
+
+    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
+        if frame_timestamp.shape[-1] != 1:
+            raise ValueError("Frame timestamp's last dimensions should be one.")
+        time = frame_timestamp / self.time_divisor
+        return self._harmonic_embedding(time) # pyre-ignore: 29
+
+    def calc_squared_encoding_norm(self):
+        return 0.0
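Not part of the diff: a minimal usage sketch of the new HarmonicTimeEncoder, assuming direct construction after expand_args_fields; the output dimension follows HarmonicEmbedding.get_output_dim for a 1-dimensional input.

import torch
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.models.global_encoder.global_encoder import HarmonicTimeEncoder

expand_args_fields(HarmonicTimeEncoder)
enc = HarmonicTimeEncoder(n_harmonic_functions=10, append_input=True, time_divisor=1.0)

# Each timestamp (last dim must be 1) is embedded with sin/cos harmonics;
# with append_input=True the output dimension is 2 * 10 + 1 = 21.
timestamps = torch.rand(4, 1)
code = enc(frame_timestamp=timestamps)  # shape (4, 21)
print(enc.get_encoding_dim())  # 21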
@@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
 n_train_target_views: 1
 sampling_mode_training: mask_sample
 sampling_mode_evaluation: full_grid
+global_encoder_class_type: SequenceAutodecoder
 raysampler_class_type: AdaptiveRaySampler
 renderer_class_type: LSTMRenderer
 image_feature_extractor_class_type: ResNetFeatureExtractor
@@ -48,11 +49,12 @@ log_vars:
 - objective
 - epoch
 - sec/it
-sequence_autodecoder_args:
-  encoding_dim: 0
-  n_instances: 0
-  init_scale: 1.0
-  ignore_input: false
+global_encoder_SequenceAutodecoder_args:
+  autodecoder_args:
+    encoding_dim: 0
+    n_instances: 0
+    init_scale: 1.0
+    ignore_input: false
 raysampler_AdaptiveRaySampler_args:
   image_width: 400
   image_height: 400
@@ -7,11 +7,13 @@
 import unittest

 from omegaconf import OmegaConf
-from pytorch3d.implicitron.models.autodecoder import Autodecoder
 from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
     ResNetFeatureExtractor,
 )
 from pytorch3d.implicitron.models.generic_model import GenericModel
+from pytorch3d.implicitron.models.global_encoder.global_encoder import (
+    SequenceAutodecoder,
+)
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )
@@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
         self.assertIsInstance(
             gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
         )
-        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
+        self.assertIsNone(gm.global_encoder)
         self.assertFalse(hasattr(gm, "implicit_function"))
         self.assertIsNone(gm.view_pooler)
         self.assertIsNone(gm.image_feature_extractor)
@@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
         )
         args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
         args.implicit_function_class_type = "IdrFeatureField"
+        args.global_encoder_class_type = "SequenceAutodecoder"
         idr_args = args.implicit_function_IdrFeatureField_args
         idr_args.n_harmonic_functions_xyz = 1729

@@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
         )
         self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
         self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
-        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
+        self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
         self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
         self.assertFalse(hasattr(gm, "implicit_function"))

@@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
+            print(DATA_DIR)
             (DATA_DIR / "overrides.yaml_").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())