Refactor autodecoders

Summary: Refactors autodecoders. Tests pass.

Reviewed By: bottler

Differential Revision: D37592429

fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
David Novotny 2022-07-04 07:18:03 -07:00 committed by Facebook GitHub Bot
parent ae35824f21
commit 0dce883241
10 changed files with 230 additions and 87 deletions

View File

@ -11,10 +11,12 @@ generic_model_args:
num_passes: 1
output_rasterized_mc: true
sampling_mode_training: mask_sample
sequence_autodecoder_args:
n_instances: 20000
init_scale: 1.0
encoding_dim: 256
global_encoder_class_type: SequenceAutodecoder
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
n_instances: 20000
init_scale: 1.0
encoding_dim: 256
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6

View File

@ -4,6 +4,8 @@ defaults:
generic_model_args:
chunk_size_grid: 16000
view_pooler_enabled: false
sequence_autodecoder_args:
n_instances: 20000
encoding_dim: 256
global_encoder_class_type: SequenceAutodecoder
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
n_instances: 20000
encoding_dim: 256

View File

@ -13,9 +13,11 @@ generic_model_args:
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.001
depth_neg_penalty: 10000.0
sequence_autodecoder_args:
encoding_dim: 256
n_instances: 20000
global_encoder_class_type: SequenceAutodecoder
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 256
n_instances: 20000
raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048

View File

@ -16,6 +16,7 @@ generic_model_args:
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
global_encoder_class_type: null
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: MultiPassEmissionAbsorptionRenderer
image_feature_extractor_class_type: null
@ -49,11 +50,16 @@ generic_model_args:
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
time_divisor: 1.0
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400

View File

@ -12,7 +12,7 @@ import logging
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import tqdm
@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
from .autodecoder import Autodecoder
from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
from .global_encoder.global_encoder import GlobalEncoderBase
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa
@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .view_pooler.view_pooler import ViewPooler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
logger = logging.getLogger(__name__)
@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
------------------
Evaluate the implicit function(s) at the sampled ray points
(optionally pass in the aggregated image features from (4)).
(also optionally pass in a global encoding from global_encoder).
(6) Rendering
@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sampling_mode_training: The sampling method to use during training. Must be
a value from the RenderSamplingMode Enum.
sampling_mode_evaluation: Same as above but for evaluation.
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
global_encoder_class_type: The name of the class to use for global_encoder,
which must be available in the registry. Or `None` to disable global encoder.
global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
of the image (referred to as the global_code) that can be used to model aspects of
the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and
@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid"
# ---- autodecoder settings
sequence_autodecoder: Autodecoder
# ---- global encoder settings
global_encoder_class_type: Optional[str] = None
global_encoder: Optional[GlobalEncoderBase]
# ---- raysampler
raysampler_class_type: str = "AdaptiveRaySampler"
@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"loss_prev_stage_rgb_psnr_fg",
"loss_prev_stage_rgb_psnr",
"loss_prev_stage_mask_bce",
*STD_LOG_VARS,
# basic metrics
"objective",
"epoch",
"sec/it",
]
)
@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
*, # force keyword-only arguments
image_rgb: Optional[torch.Tensor],
camera: CamerasBase,
fg_probability: Optional[torch.Tensor],
mask_crop: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
sequence_name: Optional[List[str]],
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
sequence_name: Optional[List[str]] = None,
frame_timestamp: Optional[torch.Tensor] = None,
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> Dict[str, Any]:
@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames.
frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
of frame timestamps.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
else min(self.n_train_target_views, batch_size)
)
# A helper function for selecting n_target first elements from the input
# where the latter can be None.
def _safe_slice_targets(
tensor: Optional[Union[torch.Tensor, List[str]]],
) -> Optional[Union[torch.Tensor, List[str]]]:
return None if tensor is None else tensor[:n_targets]
# Select the target cameras.
target_cameras = camera[list(range(n_targets))]
@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
custom_args["fun_viewpool"] = curried_viewpooler
global_code = None
if self.sequence_autodecoder.n_instances > 0:
if sequence_name is None:
raise ValueError("sequence_name must be provided for autodecoder.")
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
if self.global_encoder is not None:
global_code = self.global_encoder( # pyre-fixme[29]
sequence_name=_safe_slice_targets(sequence_name),
frame_timestamp=_safe_slice_targets(frame_timestamp),
)
custom_args["global_code"] = global_code
# pyre-fixme[29]:
@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# A dict to store losses as well as rendering results.
preds: Dict[str, Any] = {}
def safe_slice_targets(
tensor: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
return None if tensor is None else tensor[:n_targets]
preds.update(
self.view_metrics(
results=preds,
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability),
mask_crop=safe_slice_targets(mask_crop),
image_rgb=_safe_slice_targets(image_rgb),
depth_map=_safe_slice_targets(depth_map),
fg_probability=_safe_slice_targets(fg_probability),
mask_crop=_safe_slice_targets(mask_crop),
)
)
@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
**kwargs,
)
def _get_global_encoder_encoding_dim(self) -> int:
if self.global_encoder is None:
return 0
return self.global_encoder.get_encoding_dim()
def _get_viewpooled_feature_dim(self) -> int:
if self.view_pooler is None:
return 0
@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
)
nerf_args["color_dim"] = nerformer_args[
"color_dim"
@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# idr preprocessing
idr = self.implicit_function_IdrFeatureField_args
idr["feature_vector_size"] = self.render_features_dimensions
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
# srn preprocessing
srn = self.implicit_function_SRNImplicitFunction_args
srn.raymarch_function_args.latent_dim = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
)
# srn_hypernet preprocessing
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
srn_hypernet_args = srn_hypernet.hypernet_args
srn_hypernet_args.latent_dim_hypernet = (
self.sequence_autodecoder.get_encoding_dim()
)
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
# check that for srn, srn_hypernet, idr we have self.num_passes=1

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -12,10 +12,9 @@ import torch
from pytorch3d.implicitron.tools.config import Configurable
# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module):
"""
Autodecoder module
Autodecoder which maps a list of integer or string keys to optimizable embeddings.
Settings:
encoding_dim: Embedding dimension for the decoder.
@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module):
# weight has been initialised from Normal(0, 1)
self._autodecoder_codes.weight *= self.init_scale
self._sequence_map = self._build_sequence_map()
self._key_map = self._build_key_map()
# Make sure to register hooks for correct handling of saving/loading
# the module's _sequence_map.
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
self._register_state_dict_hook(_save_sequence_map_hook)
# the module's _key_map.
self._register_load_state_dict_pre_hook(self._load_key_map_hook)
self._register_state_dict_hook(_save_key_map_hook)
def _build_sequence_map(
self, sequence_map_dict: Optional[Dict[str, int]] = None
def _build_key_map(
self, key_map_dict: Optional[Dict[str, int]] = None
) -> Dict[str, int]:
"""
Args:
sequence_map_dict: A dictionary used to initialize the sequence_map.
key_map_dict: A dictionary used to initialize the key_map.
Returns:
sequence_map: a dictionary of key: id pairs.
key_map: a dictionary of key: id pairs.
"""
# increments the counter when asked for a new value
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
if sequence_map_dict is not None:
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
key_map = defaultdict(iter(range(self.n_instances)).__next__)
if key_map_dict is not None:
# Assign all keys from the loaded key_map_dict to self._key_map.
# Since this is done in the original order, it should generate
# the same set of key:id pairs. We check this with an assert to be sure.
for x, x_id in sequence_map_dict.items():
x_id_ = sequence_map[x]
for x, x_id in key_map_dict.items():
x_id_ = key_map[x]
assert x_id == x_id_
return sequence_map
return key_map
def calc_squared_encoding_norm(self):
if self.n_instances <= 0:
@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
"""
Args:
x: A batch of `N` sequence identifiers. Either a long tensor of size
x: A batch of `N` identifiers. Either a long tensor of size
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
are hashed to codes (without collisions).
Returns:
codes: A tensor of shape `(N, self.encoding_dim)` containing the
sequence-specific autodecoder codes.
key-specific autodecoder codes.
"""
if self.n_instances == 0:
return None
@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module):
# `Tensor`.
x = torch.tensor(
# pyre-ignore[29]
[self._sequence_map[elem] for elem in x],
[self._key_map[elem] for elem in x],
dtype=torch.long,
device=next(self.parameters()).device,
)
@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._autodecoder_codes(x)
def _load_sequence_map_hook(
def _load_key_map_hook(
self,
state_dict,
prefix,
@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
:meth:`~torch.nn.Module.load_state_dict`
Returns:
Constructed sequence_map if it exists in the state_dict
Constructed key_map if it exists in the state_dict
else raises a warning only.
"""
sequence_map_key = prefix + "_sequence_map"
if sequence_map_key in state_dict:
sequence_map_dict = state_dict.pop(sequence_map_key)
self._sequence_map = self._build_sequence_map(
sequence_map_dict=sequence_map_dict
)
key_map_key = prefix + "_key_map"
if key_map_key in state_dict:
key_map_dict = state_dict.pop(key_map_key)
self._key_map = self._build_key_map(key_map_dict=key_map_dict)
else:
warnings.warn("No sequence map in Autodecoder state dict!")
warnings.warn("No key map in Autodecoder state dict!")
def _save_sequence_map_hook(
def _save_key_map_hook(
self,
state_dict,
prefix,
@ -169,6 +166,6 @@ def _save_sequence_map_hook(
module
local_metadata (dict): a dict containing the metadata for this module.
"""
sequence_map_key = prefix + "_sequence_map"
sequence_map_dict = dict(self._sequence_map.items())
state_dict[sequence_map_key] = sequence_map_dict
key_map_key = prefix + "_key_map"
key_map_dict = dict(self._key_map.items())
state_dict[key_map_key] = key_map_dict

View File

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Union
import torch
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .autodecoder import Autodecoder
class GlobalEncoderBase(ReplaceableBase):
    """
    A base class for implementing encoders of global frame-specific quantities.

    The latter includes e.g. the harmonic encoding of a frame timestamp
    (`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
    (`SequenceAutodecoder`).

    This class only declares the interface: `get_encoding_dim`,
    `calc_squared_encoding_norm` and `forward` all raise `NotImplementedError`
    here and have to be overridden by the concrete encoders.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_encoding_dim(self) -> int:
        """
        Returns the dimensionality of the returned encoding.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    def calc_squared_encoding_norm(self):
        """
        Calculates the squared norm of the encoding.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Given a set of inputs to encode, generates a tensor containing the encoding.

        Args:
            **kwargs: The quantities to encode, passed by keyword. Each subclass
                picks the keyword(s) it needs and ignores the remaining ones.

        Returns:
            encoding: The tensor containing the global encoding.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 13
    """
    A global encoder implementation which provides an autodecoder encoding
    of the frame's sequence identifier.

    Members:
        autodecoder: The `Autodecoder` that maps sequence identifiers to codes.
            It is not assigned explicitly in this class; presumably it is
            created by `run_auto_creation` in `__post_init__` (TODO confirm).
    """

    autodecoder: Autodecoder

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def get_encoding_dim(self):
        """
        Returns the encoding dimensionality reported by the wrapped autodecoder.
        """
        return self.autodecoder.get_encoding_dim()

    def forward(
        self,
        # `Union[..., None]` is used instead of `Optional[...]` because this
        # module only imports `List` and `Union` from `typing`.
        sequence_name: Union[torch.LongTensor, List[str], None] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encodes a batch of sequence identifiers with the wrapped autodecoder.

        Args:
            sequence_name: A batch of sequence identifiers, either a long tensor
                or a list of string keys. May only be omitted when the
                autodecoder holds no instances.
            **kwargs: Ignored; accepted so that all global encoders can be
                called with the same set of keyword arguments.

        Returns:
            The output of `self.autodecoder(sequence_name)`.
            NOTE: the wrapped autodecoder returns `None` when it is configured
            with `n_instances == 0`.

        Raises:
            ValueError: If `sequence_name` is None while the autodecoder has
                a positive number of instances.
        """
        # The caller forwards `sequence_name` as-is, so it can be None. Fail
        # early with a clear message rather than deep inside the autodecoder.
        # An autodecoder without instances ignores its input, so it is exempt.
        if sequence_name is None and self.autodecoder.n_instances > 0:
            raise ValueError("sequence_name must be provided for autodecoder.")
        # run dtype checks and pass sequence_name to self.autodecoder
        return self.autodecoder(sequence_name)

    def calc_squared_encoding_norm(self):
        """
        Returns the squared encoding norm computed by the wrapped autodecoder.
        """
        return self.autodecoder.calc_squared_encoding_norm()
@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Settings:
        n_harmonic_functions: Passed to `HarmonicEmbedding` as
            `n_harmonic_functions`.
        append_input: Passed to `HarmonicEmbedding` as `append_input`.
        time_divisor: The timestamps are divided by this value before
            they are embedded.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        super().__init__()
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        """
        Returns the output dimensionality of the harmonic embedding
        for a one-dimensional input.
        """
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        frame_timestamp: Union[torch.Tensor, None] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Computes the harmonic embedding of the (rescaled) frame timestamps.

        Args:
            frame_timestamp: A tensor of timestamps whose last dimension
                has size one.
            **kwargs: Ignored; accepted so that all global encoders can be
                called with the same set of keyword arguments.

        Returns:
            The harmonic embedding of `frame_timestamp / self.time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is None or if its last dimension
                is not of size one.
        """
        # The caller may pass None when no timestamps are available. Report
        # that explicitly instead of failing on the `.shape` access below.
        if frame_timestamp is None:
            raise ValueError("Frame timestamp must be provided.")
        if frame_timestamp.shape[-1] != 1:
            raise ValueError("Frame timestamp's last dimensions should be one.")
        time = frame_timestamp / self.time_divisor
        return self._harmonic_embedding(time)  # pyre-ignore: 29

    def calc_squared_encoding_norm(self):
        """
        Returns 0.0; this encoder does not contribute an encoding norm.
        """
        return 0.0

View File

@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
global_encoder_class_type: SequenceAutodecoder
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer
image_feature_extractor_class_type: ResNetFeatureExtractor
@ -48,11 +49,12 @@ log_vars:
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400

View File

@ -7,11 +7,13 @@
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
SequenceAutodecoder,
)
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsNone(gm.global_encoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
)
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
args.implicit_function_class_type = "IdrFeatureField"
args.global_encoder_class_type = "SequenceAutodecoder"
idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729
@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function"))
@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
print(DATA_DIR)
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())