mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Refactor autodecoders
Summary: Refactors autodecoders. Tests pass.
Reviewed By: bottler
Differential Revision: D37592429
fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
parent
ae35824f21
commit
0dce883241
@ -11,10 +11,12 @@ generic_model_args:
|
|||||||
num_passes: 1
|
num_passes: 1
|
||||||
output_rasterized_mc: true
|
output_rasterized_mc: true
|
||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
sequence_autodecoder_args:
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
n_instances: 20000
|
global_encoder_SequenceAutodecoder_args:
|
||||||
init_scale: 1.0
|
autodecoder_args:
|
||||||
encoding_dim: 256
|
n_instances: 20000
|
||||||
|
init_scale: 1.0
|
||||||
|
encoding_dim: 256
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
n_harmonic_functions_xyz: 6
|
n_harmonic_functions_xyz: 6
|
||||||
bias: 0.6
|
bias: 0.6
|
||||||
|
@ -4,6 +4,8 @@ defaults:
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: false
|
view_pooler_enabled: false
|
||||||
sequence_autodecoder_args:
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
n_instances: 20000
|
global_encoder_SequenceAutodecoder_args:
|
||||||
encoding_dim: 256
|
autodecoder_args:
|
||||||
|
n_instances: 20000
|
||||||
|
encoding_dim: 256
|
||||||
|
@ -13,9 +13,11 @@ generic_model_args:
|
|||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
loss_autodecoder_norm: 0.001
|
loss_autodecoder_norm: 0.001
|
||||||
depth_neg_penalty: 10000.0
|
depth_neg_penalty: 10000.0
|
||||||
sequence_autodecoder_args:
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
encoding_dim: 256
|
global_encoder_SequenceAutodecoder_args:
|
||||||
n_instances: 20000
|
autodecoder_args:
|
||||||
|
encoding_dim: 256
|
||||||
|
n_instances: 20000
|
||||||
raysampler_class_type: NearFarRaySampler
|
raysampler_class_type: NearFarRaySampler
|
||||||
raysampler_NearFarRaySampler_args:
|
raysampler_NearFarRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 2048
|
n_rays_per_image_sampled_from_mask: 2048
|
||||||
|
@ -16,6 +16,7 @@ generic_model_args:
|
|||||||
n_train_target_views: 1
|
n_train_target_views: 1
|
||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
sampling_mode_evaluation: full_grid
|
sampling_mode_evaluation: full_grid
|
||||||
|
global_encoder_class_type: null
|
||||||
raysampler_class_type: AdaptiveRaySampler
|
raysampler_class_type: AdaptiveRaySampler
|
||||||
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
||||||
image_feature_extractor_class_type: null
|
image_feature_extractor_class_type: null
|
||||||
@ -49,11 +50,16 @@ generic_model_args:
|
|||||||
- objective
|
- objective
|
||||||
- epoch
|
- epoch
|
||||||
- sec/it
|
- sec/it
|
||||||
sequence_autodecoder_args:
|
global_encoder_HarmonicTimeEncoder_args:
|
||||||
encoding_dim: 0
|
n_harmonic_functions: 10
|
||||||
n_instances: 0
|
append_input: true
|
||||||
init_scale: 1.0
|
time_divisor: 1.0
|
||||||
ignore_input: false
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
|
encoding_dim: 0
|
||||||
|
n_instances: 0
|
||||||
|
init_scale: 1.0
|
||||||
|
ignore_input: false
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
image_width: 400
|
image_width: 400
|
||||||
image_height: 400
|
image_height: 400
|
||||||
|
@ -12,7 +12,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils
|
|||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
|
|
||||||
from .autodecoder import Autodecoder
|
|
||||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||||
from .feature_extractor import FeatureExtractorBase
|
from .feature_extractor import FeatureExtractorBase
|
||||||
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
||||||
|
from .global_encoder.global_encoder import GlobalEncoderBase
|
||||||
from .implicit_function.base import ImplicitFunctionBase
|
from .implicit_function.base import ImplicitFunctionBase
|
||||||
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
||||||
from .implicit_function.neural_radiance_field import ( # noqa
|
from .implicit_function.neural_radiance_field import ( # noqa
|
||||||
@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
|||||||
from .view_pooler.view_pooler import ViewPooler
|
from .view_pooler.view_pooler import ViewPooler
|
||||||
|
|
||||||
|
|
||||||
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
------------------
|
------------------
|
||||||
Evaluate the implicit function(s) at the sampled ray points
|
Evaluate the implicit function(s) at the sampled ray points
|
||||||
(optionally pass in the aggregated image features from (4)).
|
(optionally pass in the aggregated image features from (4)).
|
||||||
|
(also optionally pass in a global encoding from global_encoder).
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
(6) Rendering
|
(6) Rendering
|
||||||
@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
sampling_mode_training: The sampling method to use during training. Must be
|
sampling_mode_training: The sampling method to use during training. Must be
|
||||||
a value from the RenderSamplingMode Enum.
|
a value from the RenderSamplingMode Enum.
|
||||||
sampling_mode_evaluation: Same as above but for evaluation.
|
sampling_mode_evaluation: Same as above but for evaluation.
|
||||||
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
|
global_encoder_class_type: The name of the class to use for global_encoder,
|
||||||
|
which must be available in the registry. Or `None` to disable global encoder.
|
||||||
|
global_encoder: An instance of `GlobalEncoderBase`. This is used to generate an encoding
|
||||||
of the image (referred to as the global_code) that can be used to model aspects of
|
of the image (referred to as the global_code) that can be used to model aspects of
|
||||||
the scene such as multiple objects or morphing objects. It is up to the implicit
|
the scene such as multiple objects or morphing objects. It is up to the implicit
|
||||||
function definition how to use it, but the most typical way is to broadcast and
|
function definition how to use it, but the most typical way is to broadcast and
|
||||||
@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
sampling_mode_training: str = "mask_sample"
|
sampling_mode_training: str = "mask_sample"
|
||||||
sampling_mode_evaluation: str = "full_grid"
|
sampling_mode_evaluation: str = "full_grid"
|
||||||
|
|
||||||
# ---- autodecoder settings
|
# ---- global encoder settings
|
||||||
sequence_autodecoder: Autodecoder
|
global_encoder_class_type: Optional[str] = None
|
||||||
|
global_encoder: Optional[GlobalEncoderBase]
|
||||||
|
|
||||||
# ---- raysampler
|
# ---- raysampler
|
||||||
raysampler_class_type: str = "AdaptiveRaySampler"
|
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||||
@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
"loss_prev_stage_rgb_psnr_fg",
|
"loss_prev_stage_rgb_psnr_fg",
|
||||||
"loss_prev_stage_rgb_psnr",
|
"loss_prev_stage_rgb_psnr",
|
||||||
"loss_prev_stage_mask_bce",
|
"loss_prev_stage_mask_bce",
|
||||||
*STD_LOG_VARS,
|
# basic metrics
|
||||||
|
"objective",
|
||||||
|
"epoch",
|
||||||
|
"sec/it",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
*, # force keyword-only arguments
|
*, # force keyword-only arguments
|
||||||
image_rgb: Optional[torch.Tensor],
|
image_rgb: Optional[torch.Tensor],
|
||||||
camera: CamerasBase,
|
camera: CamerasBase,
|
||||||
fg_probability: Optional[torch.Tensor],
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
mask_crop: Optional[torch.Tensor],
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
depth_map: Optional[torch.Tensor],
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
sequence_name: Optional[List[str]],
|
sequence_name: Optional[List[str]] = None,
|
||||||
|
frame_timestamp: Optional[torch.Tensor] = None,
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
sequence_name: A list of `B` strings corresponding to the sequence names
|
sequence_name: A list of `B` strings corresponding to the sequence names
|
||||||
from which images `image_rgb` were extracted. They are used to match
|
from which images `image_rgb` were extracted. They are used to match
|
||||||
target frames with relevant source frames.
|
target frames with relevant source frames.
|
||||||
|
frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
|
||||||
|
of frame timestamps.
|
||||||
evaluation_mode: one of EvaluationMode.TRAINING or
|
evaluation_mode: one of EvaluationMode.TRAINING or
|
||||||
EvaluationMode.EVALUATION which determines the settings used for
|
EvaluationMode.EVALUATION which determines the settings used for
|
||||||
rendering.
|
rendering.
|
||||||
@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
else min(self.n_train_target_views, batch_size)
|
else min(self.n_train_target_views, batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A helper function for selecting the first n_targets elements from the input
|
||||||
|
# where the latter can be None.
|
||||||
|
def _safe_slice_targets(
|
||||||
|
tensor: Optional[Union[torch.Tensor, List[str]]],
|
||||||
|
) -> Optional[Union[torch.Tensor, List[str]]]:
|
||||||
|
return None if tensor is None else tensor[:n_targets]
|
||||||
|
|
||||||
# Select the target cameras.
|
# Select the target cameras.
|
||||||
target_cameras = camera[list(range(n_targets))]
|
target_cameras = camera[list(range(n_targets))]
|
||||||
|
|
||||||
@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
custom_args["fun_viewpool"] = curried_viewpooler
|
custom_args["fun_viewpool"] = curried_viewpooler
|
||||||
|
|
||||||
global_code = None
|
global_code = None
|
||||||
if self.sequence_autodecoder.n_instances > 0:
|
if self.global_encoder is not None:
|
||||||
if sequence_name is None:
|
global_code = self.global_encoder( # pyre-fixme[29]
|
||||||
raise ValueError("sequence_name must be provided for autodecoder.")
|
sequence_name=_safe_slice_targets(sequence_name),
|
||||||
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
|
frame_timestamp=_safe_slice_targets(frame_timestamp),
|
||||||
|
)
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
# pyre-fixme[29]:
|
# pyre-fixme[29]:
|
||||||
@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# A dict to store losses as well as rendering results.
|
# A dict to store losses as well as rendering results.
|
||||||
preds: Dict[str, Any] = {}
|
preds: Dict[str, Any] = {}
|
||||||
|
|
||||||
def safe_slice_targets(
|
|
||||||
tensor: Optional[torch.Tensor],
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
return None if tensor is None else tensor[:n_targets]
|
|
||||||
|
|
||||||
preds.update(
|
preds.update(
|
||||||
self.view_metrics(
|
self.view_metrics(
|
||||||
results=preds,
|
results=preds,
|
||||||
raymarched=rendered,
|
raymarched=rendered,
|
||||||
xys=ray_bundle.xys,
|
xys=ray_bundle.xys,
|
||||||
image_rgb=safe_slice_targets(image_rgb),
|
image_rgb=_safe_slice_targets(image_rgb),
|
||||||
depth_map=safe_slice_targets(depth_map),
|
depth_map=_safe_slice_targets(depth_map),
|
||||||
fg_probability=safe_slice_targets(fg_probability),
|
fg_probability=_safe_slice_targets(fg_probability),
|
||||||
mask_crop=safe_slice_targets(mask_crop),
|
mask_crop=_safe_slice_targets(mask_crop),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_global_encoder_encoding_dim(self) -> int:
|
||||||
|
if self.global_encoder is None:
|
||||||
|
return 0
|
||||||
|
return self.global_encoder.get_encoding_dim()
|
||||||
|
|
||||||
def _get_viewpooled_feature_dim(self) -> int:
|
def _get_viewpooled_feature_dim(self) -> int:
|
||||||
if self.view_pooler is None:
|
if self.view_pooler is None:
|
||||||
return 0
|
return 0
|
||||||
@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
||||||
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
||||||
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
||||||
self._get_viewpooled_feature_dim()
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
+ self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
)
|
||||||
nerf_args["color_dim"] = nerformer_args[
|
nerf_args["color_dim"] = nerformer_args[
|
||||||
"color_dim"
|
"color_dim"
|
||||||
@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# idr preprocessing
|
# idr preprocessing
|
||||||
idr = self.implicit_function_IdrFeatureField_args
|
idr = self.implicit_function_IdrFeatureField_args
|
||||||
idr["feature_vector_size"] = self.render_features_dimensions
|
idr["feature_vector_size"] = self.render_features_dimensions
|
||||||
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
|
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
||||||
|
|
||||||
# srn preprocessing
|
# srn preprocessing
|
||||||
srn = self.implicit_function_SRNImplicitFunction_args
|
srn = self.implicit_function_SRNImplicitFunction_args
|
||||||
srn.raymarch_function_args.latent_dim = (
|
srn.raymarch_function_args.latent_dim = (
|
||||||
self._get_viewpooled_feature_dim()
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
+ self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# srn_hypernet preprocessing
|
# srn_hypernet preprocessing
|
||||||
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
||||||
srn_hypernet_args = srn_hypernet.hypernet_args
|
srn_hypernet_args = srn_hypernet.hypernet_args
|
||||||
srn_hypernet_args.latent_dim_hypernet = (
|
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
|
||||||
self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
|
||||||
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
||||||
|
|
||||||
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
||||||
|
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
@ -12,10 +12,9 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.config import Configurable
|
from pytorch3d.implicitron.tools.config import Configurable
|
||||||
|
|
||||||
|
|
||||||
# TODO: probabilistic embeddings?
|
|
||||||
class Autodecoder(Configurable, torch.nn.Module):
|
class Autodecoder(Configurable, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Autodecoder module
|
Autodecoder which maps a list of integer or string keys to optimizable embeddings.
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
encoding_dim: Embedding dimension for the decoder.
|
encoding_dim: Embedding dimension for the decoder.
|
||||||
@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# weight has been initialised from Normal(0, 1)
|
# weight has been initialised from Normal(0, 1)
|
||||||
self._autodecoder_codes.weight *= self.init_scale
|
self._autodecoder_codes.weight *= self.init_scale
|
||||||
|
|
||||||
self._sequence_map = self._build_sequence_map()
|
self._key_map = self._build_key_map()
|
||||||
# Make sure to register hooks for correct handling of saving/loading
|
# Make sure to register hooks for correct handling of saving/loading
|
||||||
# the module's _sequence_map.
|
# the module's _key_map.
|
||||||
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
|
self._register_load_state_dict_pre_hook(self._load_key_map_hook)
|
||||||
self._register_state_dict_hook(_save_sequence_map_hook)
|
self._register_state_dict_hook(_save_key_map_hook)
|
||||||
|
|
||||||
def _build_sequence_map(
|
def _build_key_map(
|
||||||
self, sequence_map_dict: Optional[Dict[str, int]] = None
|
self, key_map_dict: Optional[Dict[str, int]] = None
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
sequence_map_dict: A dictionary used to initialize the sequence_map.
|
key_map_dict: A dictionary used to initialize the key_map.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sequence_map: a dictionary of key: id pairs.
|
key_map: a dictionary of key: id pairs.
|
||||||
"""
|
"""
|
||||||
# increments the counter when asked for a new value
|
# increments the counter when asked for a new value
|
||||||
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
|
key_map = defaultdict(iter(range(self.n_instances)).__next__)
|
||||||
if sequence_map_dict is not None:
|
if key_map_dict is not None:
|
||||||
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
|
# Assign all keys from the loaded key_map_dict to self._key_map.
|
||||||
# Since this is done in the original order, it should generate
|
# Since this is done in the original order, it should generate
|
||||||
# the same set of key:id pairs. We check this with an assert to be sure.
|
# the same set of key:id pairs. We check this with an assert to be sure.
|
||||||
for x, x_id in sequence_map_dict.items():
|
for x, x_id in key_map_dict.items():
|
||||||
x_id_ = sequence_map[x]
|
x_id_ = key_map[x]
|
||||||
assert x_id == x_id_
|
assert x_id == x_id_
|
||||||
return sequence_map
|
return key_map
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calc_squared_encoding_norm(self):
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: A batch of `N` sequence identifiers. Either a long tensor of size
|
x: A batch of `N` identifiers. Either a long tensor of size
|
||||||
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
|
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
|
||||||
are hashed to codes (without collisions).
|
are hashed to codes (without collisions).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
||||||
sequence-specific autodecoder codes.
|
key-specific autodecoder codes.
|
||||||
"""
|
"""
|
||||||
if self.n_instances == 0:
|
if self.n_instances == 0:
|
||||||
return None
|
return None
|
||||||
@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# `Tensor`.
|
# `Tensor`.
|
||||||
x = torch.tensor(
|
x = torch.tensor(
|
||||||
# pyre-ignore[29]
|
# pyre-ignore[29]
|
||||||
[self._sequence_map[elem] for elem in x],
|
[self._key_map[elem] for elem in x],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
return self._autodecoder_codes(x)
|
return self._autodecoder_codes(x)
|
||||||
|
|
||||||
def _load_sequence_map_hook(
|
def _load_key_map_hook(
|
||||||
self,
|
self,
|
||||||
state_dict,
|
state_dict,
|
||||||
prefix,
|
prefix,
|
||||||
@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
:meth:`~torch.nn.Module.load_state_dict`
|
:meth:`~torch.nn.Module.load_state_dict`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Constructed sequence_map if it exists in the state_dict
|
Constructed key_map if it exists in the state_dict
|
||||||
else raises a warning only.
|
else raises a warning only.
|
||||||
"""
|
"""
|
||||||
sequence_map_key = prefix + "_sequence_map"
|
key_map_key = prefix + "_key_map"
|
||||||
if sequence_map_key in state_dict:
|
if key_map_key in state_dict:
|
||||||
sequence_map_dict = state_dict.pop(sequence_map_key)
|
key_map_dict = state_dict.pop(key_map_key)
|
||||||
self._sequence_map = self._build_sequence_map(
|
self._key_map = self._build_key_map(key_map_dict=key_map_dict)
|
||||||
sequence_map_dict=sequence_map_dict
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
warnings.warn("No sequence map in Autodecoder state dict!")
|
warnings.warn("No key map in Autodecoder state dict!")
|
||||||
|
|
||||||
|
|
||||||
def _save_sequence_map_hook(
|
def _save_key_map_hook(
|
||||||
self,
|
self,
|
||||||
state_dict,
|
state_dict,
|
||||||
prefix,
|
prefix,
|
||||||
@ -169,6 +166,6 @@ def _save_sequence_map_hook(
|
|||||||
module
|
module
|
||||||
local_metadata (dict): a dict containing the metadata for this module.
|
local_metadata (dict): a dict containing the metadata for this module.
|
||||||
"""
|
"""
|
||||||
sequence_map_key = prefix + "_sequence_map"
|
key_map_key = prefix + "_key_map"
|
||||||
sequence_map_dict = dict(self._sequence_map.items())
|
key_map_dict = dict(self._key_map.items())
|
||||||
state_dict[sequence_map_key] = sequence_map_dict
|
state_dict[key_map_key] = key_map_dict
|
110
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
110
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
registry,
|
||||||
|
ReplaceableBase,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
|
from .autodecoder import Autodecoder
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalEncoderBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
A base class for implementing encoders of global frame-specific quantities.
|
||||||
|
|
||||||
|
The latter includes e.g. the harmonic encoding of a frame timestamp
|
||||||
|
(`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
|
||||||
|
(`SequenceAutodecoder`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
"""
|
||||||
|
Returns the dimensionality of the returned encoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def calc_squared_encoding_norm(self):
|
||||||
|
"""
|
||||||
|
Calculates the squared norm of the encoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def forward(self, **kwargs) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Given a set of inputs to encode, generates a tensor containing the encoding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
encoding: The tensor containing the global encoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: probabilistic embeddings?
|
||||||
|
@registry.register
|
||||||
|
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
|
||||||
|
"""
|
||||||
|
A global encoder implementation which provides an autodecoder encoding
|
||||||
|
of the frame's sequence identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
autodecoder: Autodecoder
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
return self.autodecoder.get_encoding_dim()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# pass sequence_name straight through to self.autodecoder
|
||||||
|
return self.autodecoder(sequence_name)
|
||||||
|
|
||||||
|
def calc_squared_encoding_norm(self):
|
||||||
|
return self.autodecoder.calc_squared_encoding_norm()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A global encoder implementation which provides harmonic embeddings
|
||||||
|
of each frame's timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_harmonic_functions: int = 10
|
||||||
|
append_input: bool = True
|
||||||
|
time_divisor: float = 1.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._harmonic_embedding = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=self.n_harmonic_functions,
|
||||||
|
append_input=self.append_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
return self._harmonic_embedding.get_output_dim(1)
|
||||||
|
|
||||||
|
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
|
if frame_timestamp.shape[-1] != 1:
|
||||||
|
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||||
|
time = frame_timestamp / self.time_divisor
|
||||||
|
return self._harmonic_embedding(time) # pyre-ignore: 29
|
||||||
|
|
||||||
|
def calc_squared_encoding_norm(self):
|
||||||
|
return 0.0
|
@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
|
|||||||
n_train_target_views: 1
|
n_train_target_views: 1
|
||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
sampling_mode_evaluation: full_grid
|
sampling_mode_evaluation: full_grid
|
||||||
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
raysampler_class_type: AdaptiveRaySampler
|
raysampler_class_type: AdaptiveRaySampler
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
@ -48,11 +49,12 @@ log_vars:
|
|||||||
- objective
|
- objective
|
||||||
- epoch
|
- epoch
|
||||||
- sec/it
|
- sec/it
|
||||||
sequence_autodecoder_args:
|
global_encoder_SequenceAutodecoder_args:
|
||||||
encoding_dim: 0
|
autodecoder_args:
|
||||||
n_instances: 0
|
encoding_dim: 0
|
||||||
init_scale: 1.0
|
n_instances: 0
|
||||||
ignore_input: false
|
init_scale: 1.0
|
||||||
|
ignore_input: false
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
image_width: 400
|
image_width: 400
|
||||||
image_height: 400
|
image_height: 400
|
||||||
|
@ -7,11 +7,13 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
|
||||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||||
ResNetFeatureExtractor,
|
ResNetFeatureExtractor,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
|
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
|
||||||
|
SequenceAutodecoder,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||||
IdrFeatureField,
|
IdrFeatureField,
|
||||||
)
|
)
|
||||||
@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||||
)
|
)
|
||||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
self.assertIsNone(gm.global_encoder)
|
||||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||||
self.assertIsNone(gm.view_pooler)
|
self.assertIsNone(gm.view_pooler)
|
||||||
self.assertIsNone(gm.image_feature_extractor)
|
self.assertIsNone(gm.image_feature_extractor)
|
||||||
@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||||
args.implicit_function_class_type = "IdrFeatureField"
|
args.implicit_function_class_type = "IdrFeatureField"
|
||||||
|
args.global_encoder_class_type = "SequenceAutodecoder"
|
||||||
idr_args = args.implicit_function_IdrFeatureField_args
|
idr_args = args.implicit_function_IdrFeatureField_args
|
||||||
idr_args.n_harmonic_functions_xyz = 1729
|
idr_args.n_harmonic_functions_xyz = 1729
|
||||||
|
|
||||||
@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||||
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
|
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
|
||||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
|
||||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||||
|
|
||||||
@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
remove_unused_components(instance_args)
|
remove_unused_components(instance_args)
|
||||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
|
print(DATA_DIR)
|
||||||
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user