diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml index d0181eec..183abae8 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -11,10 +11,12 @@ generic_model_args: num_passes: 1 output_rasterized_mc: true sampling_mode_training: mask_sample - sequence_autodecoder_args: - n_instances: 20000 - init_scale: 1.0 - encoding_dim: 256 + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + n_instances: 20000 + init_scale: 1.0 + encoding_dim: 256 implicit_function_IdrFeatureField_args: n_harmonic_functions_xyz: 6 bias: 0.6 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml index f9a978f5..a8b99df8 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml @@ -4,6 +4,8 @@ defaults: generic_model_args: chunk_size_grid: 16000 view_pooler_enabled: false - sequence_autodecoder_args: - n_instances: 20000 - encoding_dim: 256 + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + n_instances: 20000 + encoding_dim: 256 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml index 8c8ef0d7..8d88f736 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml @@ -13,9 +13,11 @@ generic_model_args: loss_prev_stage_mask_bce: 0.0 loss_autodecoder_norm: 0.001 depth_neg_penalty: 10000.0 - sequence_autodecoder_args: - encoding_dim: 256 - n_instances: 20000 + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 256 + n_instances: 20000 raysampler_class_type: NearFarRaySampler raysampler_NearFarRaySampler_args: n_rays_per_image_sampled_from_mask: 2048 diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 1bdc0ddf..3eefa12e 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -16,6 +16,7 @@ generic_model_args: n_train_target_views: 1 sampling_mode_training: mask_sample sampling_mode_evaluation: full_grid + global_encoder_class_type: null raysampler_class_type: AdaptiveRaySampler renderer_class_type: MultiPassEmissionAbsorptionRenderer image_feature_extractor_class_type: null @@ -49,11 +50,16 @@ generic_model_args: - objective - epoch - sec/it - sequence_autodecoder_args: - encoding_dim: 0 - n_instances: 0 - init_scale: 1.0 - ignore_input: false + global_encoder_HarmonicTimeEncoder_args: + n_harmonic_functions: 10 + append_input: true + time_divisor: 1.0 + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 0 + n_instances: 0 + init_scale: 1.0 + ignore_input: false raysampler_AdaptiveRaySampler_args: image_width: 400 image_height: 400 diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index e7ee80df..78718e90 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -12,7 +12,7 @@ import logging import math import warnings from dataclasses import field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import tqdm @@ -34,10 +34,10 @@ from pytorch3d.renderer import RayBundle, utils as rend_utils from pytorch3d.renderer.cameras import CamerasBase from visdom import Visdom -from .autodecoder import Autodecoder from .base_model import ImplicitronModelBase, ImplicitronRender from .feature_extractor import FeatureExtractorBase from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa +from .global_encoder.global_encoder import GlobalEncoderBase from .implicit_function.base import ImplicitFunctionBase from .implicit_function.idr_feature_field import IdrFeatureField # noqa from .implicit_function.neural_radiance_field import ( # noqa @@ -63,7 +63,6 @@ from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .view_pooler.view_pooler import ViewPooler -STD_LOG_VARS = ["objective", "epoch", "sec/it"] logger = logging.getLogger(__name__) @@ -109,6 +108,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ------------------ Evaluate the implicit function(s) at the sampled ray points (optionally pass in the aggregated image features from (4)). + (also optionally pass in a global encoding from global_encoder). │ ▼ (6) Rendering @@ -163,7 +163,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 sampling_mode_training: The sampling method to use during training. Must be a value from the RenderSamplingMode Enum. sampling_mode_evaluation: Same as above but for evaluation. - sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding + global_encoder_class_type: The name of the class to use for global_encoder, + which must be available in the registry. Or `None` to disable global encoder. + global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding of the image (referred to as the global_code) that can be used to model aspects of the scene such as multiple objects or morphing objects. It is up to the implicit function definition how to use it, but the most typical way is to broadcast and @@ -221,8 +223,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 sampling_mode_training: str = "mask_sample" sampling_mode_evaluation: str = "full_grid" - # ---- autodecoder settings - sequence_autodecoder: Autodecoder + # ---- global encoder settings + global_encoder_class_type: Optional[str] = None + global_encoder: Optional[GlobalEncoderBase] # ---- raysampler raysampler_class_type: str = "AdaptiveRaySampler" @@ -284,7 +287,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 "loss_prev_stage_rgb_psnr_fg", "loss_prev_stage_rgb_psnr", "loss_prev_stage_mask_bce", - *STD_LOG_VARS, + # basic metrics + "objective", + "epoch", + "sec/it", ] ) @@ -307,10 +313,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 *, # force keyword-only arguments image_rgb: Optional[torch.Tensor], camera: CamerasBase, - fg_probability: Optional[torch.Tensor], - mask_crop: Optional[torch.Tensor], - depth_map: Optional[torch.Tensor], - sequence_name: Optional[List[str]], + fg_probability: Optional[torch.Tensor] = None, + mask_crop: Optional[torch.Tensor] = None, + depth_map: Optional[torch.Tensor] = None, + sequence_name: Optional[List[str]] = None, + frame_timestamp: Optional[torch.Tensor] = None, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, ) -> Dict[str, Any]: @@ -333,6 +340,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 sequence_name: A list of `B` strings corresponding to the sequence names from which images `image_rgb` were extracted. They are used to match target frames with relevant source frames. + frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch + of frame timestamps. evaluation_mode: one of EvaluationMode.TRAINING or EvaluationMode.EVALUATION which determines the settings used for rendering. @@ -357,6 +366,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 else min(self.n_train_target_views, batch_size) ) + # A helper function for selecting n_target first elements from the input + # where the latter can be None. + def _safe_slice_targets( + tensor: Optional[Union[torch.Tensor, List[str]]], + ) -> Optional[Union[torch.Tensor, List[str]]]: + return None if tensor is None else tensor[:n_targets] + # Select the target cameras. target_cameras = camera[list(range(n_targets))] @@ -405,10 +421,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 custom_args["fun_viewpool"] = curried_viewpooler global_code = None - if self.sequence_autodecoder.n_instances > 0: - if sequence_name is None: - raise ValueError("sequence_name must be provided for autodecoder.") - global_code = self.sequence_autodecoder(sequence_name[:n_targets]) + if self.global_encoder is not None: + global_code = self.global_encoder( # pyre-fixme[29] + sequence_name=_safe_slice_targets(sequence_name), + frame_timestamp=_safe_slice_targets(frame_timestamp), + ) custom_args["global_code"] = global_code # pyre-fixme[29]: @@ -447,20 +464,15 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # A dict to store losses as well as rendering results. preds: Dict[str, Any] = {} - def safe_slice_targets( - tensor: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: - return None if tensor is None else tensor[:n_targets] - preds.update( self.view_metrics( results=preds, raymarched=rendered, xys=ray_bundle.xys, - image_rgb=safe_slice_targets(image_rgb), - depth_map=safe_slice_targets(depth_map), - fg_probability=safe_slice_targets(fg_probability), - mask_crop=safe_slice_targets(mask_crop), + image_rgb=_safe_slice_targets(image_rgb), + depth_map=_safe_slice_targets(depth_map), + fg_probability=_safe_slice_targets(fg_probability), + mask_crop=_safe_slice_targets(mask_crop), ) ) @@ -592,6 +604,11 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 **kwargs, ) + def _get_global_encoder_encoding_dim(self) -> int: + if self.global_encoder is None: + return 0 + return self.global_encoder.get_encoding_dim() + def _get_viewpooled_feature_dim(self) -> int: if self.view_pooler is None: return 0 @@ -668,8 +685,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args nerf_args["latent_dim"] = nerformer_args["latent_dim"] = ( - self._get_viewpooled_feature_dim() - + self.sequence_autodecoder.get_encoding_dim() + self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() ) nerf_args["color_dim"] = nerformer_args[ "color_dim" @@ -678,21 +694,18 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # idr preprocessing idr = self.implicit_function_IdrFeatureField_args idr["feature_vector_size"] = self.render_features_dimensions - idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim() + idr["encoding_dim"] = self._get_global_encoder_encoding_dim() # srn preprocessing srn = self.implicit_function_SRNImplicitFunction_args srn.raymarch_function_args.latent_dim = ( - self._get_viewpooled_feature_dim() - + self.sequence_autodecoder.get_encoding_dim() + self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() ) # srn_hypernet preprocessing srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args srn_hypernet_args = srn_hypernet.hypernet_args - srn_hypernet_args.latent_dim_hypernet = ( - self.sequence_autodecoder.get_encoding_dim() - ) + srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim() srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim() # check that for srn, srn_hypernet, idr we have self.num_passes=1 diff --git a/pytorch3d/implicitron/models/global_encoder/__init__.py b/pytorch3d/implicitron/models/global_encoder/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/pytorch3d/implicitron/models/global_encoder/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pytorch3d/implicitron/models/autodecoder.py b/pytorch3d/implicitron/models/global_encoder/autodecoder.py similarity index 73% rename from pytorch3d/implicitron/models/autodecoder.py rename to pytorch3d/implicitron/models/global_encoder/autodecoder.py index f7910262..3089ac68 100644 --- a/pytorch3d/implicitron/models/autodecoder.py +++ b/pytorch3d/implicitron/models/global_encoder/autodecoder.py @@ -12,10 +12,9 @@ import torch from pytorch3d.implicitron.tools.config import Configurable -# TODO: probabilistic embeddings? class Autodecoder(Configurable, torch.nn.Module): """ - Autodecoder module + Autodecoder which maps a list of integer or string keys to optimizable embeddings. Settings: encoding_dim: Embedding dimension for the decoder. @@ -43,32 +42,32 @@ class Autodecoder(Configurable, torch.nn.Module): # weight has been initialised from Normal(0, 1) self._autodecoder_codes.weight *= self.init_scale - self._sequence_map = self._build_sequence_map() + self._key_map = self._build_key_map() # Make sure to register hooks for correct handling of saving/loading - # the module's _sequence_map. - self._register_load_state_dict_pre_hook(self._load_sequence_map_hook) - self._register_state_dict_hook(_save_sequence_map_hook) + # the module's _key_map. + self._register_load_state_dict_pre_hook(self._load_key_map_hook) + self._register_state_dict_hook(_save_key_map_hook) - def _build_sequence_map( - self, sequence_map_dict: Optional[Dict[str, int]] = None + def _build_key_map( + self, key_map_dict: Optional[Dict[str, int]] = None ) -> Dict[str, int]: """ Args: - sequence_map_dict: A dictionary used to initialize the sequence_map. + key_map_dict: A dictionary used to initialize the key_map. Returns: - sequence_map: a dictionary of key: id pairs. + key_map: a dictionary of key: id pairs. """ # increments the counter when asked for a new value - sequence_map = defaultdict(iter(range(self.n_instances)).__next__) - if sequence_map_dict is not None: - # Assign all keys from the loaded sequence_map_dict to self._sequence_map. + key_map = defaultdict(iter(range(self.n_instances)).__next__) + if key_map_dict is not None: + # Assign all keys from the loaded key_map_dict to self._key_map. # Since this is done in the original order, it should generate # the same set of key:id pairs. We check this with an assert to be sure. - for x, x_id in sequence_map_dict.items(): - x_id_ = sequence_map[x] + for x, x_id in key_map_dict.items(): + x_id_ = key_map[x] assert x_id == x_id_ - return sequence_map + return key_map def calc_squared_encoding_norm(self): if self.n_instances <= 0: @@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module): def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]: """ Args: - x: A batch of `N` sequence identifiers. Either a long tensor of size + x: A batch of `N` identifiers. Either a long tensor of size `(N,)` keys in [0, n_instances), or a list of `N` string keys that are hashed to codes (without collisions). Returns: codes: A tensor of shape `(N, self.encoding_dim)` containing the - sequence-specific autodecoder codes. + key-specific autodecoder codes. """ if self.n_instances == 0: return None @@ -103,7 +102,7 @@ class Autodecoder(Configurable, torch.nn.Module): # `Tensor`. x = torch.tensor( # pyre-ignore[29] - [self._sequence_map[elem] for elem in x], + [self._key_map[elem] for elem in x], dtype=torch.long, device=next(self.parameters()).device, ) @@ -113,7 +112,7 @@ class Autodecoder(Configurable, torch.nn.Module): # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. return self._autodecoder_codes(x) - def _load_sequence_map_hook( + def _load_key_map_hook( self, state_dict, prefix, @@ -142,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module): :meth:`~torch.nn.Module.load_state_dict` Returns: - Constructed sequence_map if it exists in the state_dict + Constructed key_map if it exists in the state_dict else raises a warning only. """ - sequence_map_key = prefix + "_sequence_map" - if sequence_map_key in state_dict: - sequence_map_dict = state_dict.pop(sequence_map_key) - self._sequence_map = self._build_sequence_map( - sequence_map_dict=sequence_map_dict - ) + key_map_key = prefix + "_key_map" + if key_map_key in state_dict: + key_map_dict = state_dict.pop(key_map_key) + self._key_map = self._build_key_map(key_map_dict=key_map_dict) else: - warnings.warn("No sequence map in Autodecoder state dict!") + warnings.warn("No key map in Autodecoder state dict!") -def _save_sequence_map_hook( +def _save_key_map_hook( self, state_dict, prefix, @@ -169,6 +166,6 @@ def _save_sequence_map_hook( module local_metadata (dict): a dict containing the metadata for this module. """ - sequence_map_key = prefix + "_sequence_map" - sequence_map_dict = dict(self._sequence_map.items()) - state_dict[sequence_map_key] = sequence_map_dict + key_map_key = prefix + "_key_map" + key_map_dict = dict(self._key_map.items()) + state_dict[key_map_key] = key_map_dict diff --git a/pytorch3d/implicitron/models/global_encoder/global_encoder.py b/pytorch3d/implicitron/models/global_encoder/global_encoder.py new file mode 100644 index 00000000..37a6d7d6 --- /dev/null +++ b/pytorch3d/implicitron/models/global_encoder/global_encoder.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Union + +import torch +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from pytorch3d.renderer.implicit import HarmonicEmbedding + +from .autodecoder import Autodecoder + + +class GlobalEncoderBase(ReplaceableBase): + """ + A base class for implementing encoders of global frame-specific quantities. + + The latter includes e.g. the harmonic encoding of a frame timestamp + (`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence + (`SequenceAutodecoder`). + """ + + def __init__(self) -> None: + super().__init__() + + def get_encoding_dim(self): + """ + Returns the dimensionality of the returned encoding. + """ + raise NotImplementedError() + + def calc_squared_encoding_norm(self): + """ + Calculates the squared norm of the encoding. + """ + raise NotImplementedError() + + def forward(self, **kwargs) -> torch.Tensor: + """ + Given a set of inputs to encode, generates a tensor containing the encoding. + + Returns: + encoding: The tensor containing the global encoding. + """ + raise NotImplementedError() + + +# TODO: probabilistic embeddings? +@registry.register +class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13 + """ + A global encoder implementation which provides an autodecoder encoding + of the frame's sequence identifier. + """ + + autodecoder: Autodecoder + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + def get_encoding_dim(self): + return self.autodecoder.get_encoding_dim() + + def forward( + self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs + ) -> torch.Tensor: + + # run dtype checks and pass sequence_name to self.autodecoder + return self.autodecoder(sequence_name) + + def calc_squared_encoding_norm(self): + return self.autodecoder.calc_squared_encoding_norm() + + +@registry.register +class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module): + """ + A global encoder implementation which provides harmonic embeddings + of each frame's timestamp. + """ + + n_harmonic_functions: int = 10 + append_input: bool = True + time_divisor: float = 1.0 + + def __post_init__(self): + super().__init__() + self._harmonic_embedding = HarmonicEmbedding( + n_harmonic_functions=self.n_harmonic_functions, + append_input=self.append_input, + ) + + def get_encoding_dim(self): + return self._harmonic_embedding.get_output_dim(1) + + def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor: + if frame_timestamp.shape[-1] != 1: + raise ValueError("Frame timestamp's last dimensions should be one.") + time = frame_timestamp / self.time_divisor + return self._harmonic_embedding(time) # pyre-ignore: 29 + + def calc_squared_encoding_norm(self): + return 0.0 diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index f1826ae2..667225e7 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16 n_train_target_views: 1 sampling_mode_training: mask_sample sampling_mode_evaluation: full_grid +global_encoder_class_type: SequenceAutodecoder raysampler_class_type: AdaptiveRaySampler renderer_class_type: LSTMRenderer image_feature_extractor_class_type: ResNetFeatureExtractor @@ -48,11 +49,12 @@ log_vars: - objective - epoch - sec/it -sequence_autodecoder_args: - encoding_dim: 0 - n_instances: 0 - init_scale: 1.0 - ignore_input: false +global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 0 + n_instances: 0 + init_scale: 1.0 + ignore_input: false raysampler_AdaptiveRaySampler_args: image_width: 400 image_height: 400 diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index bef76651..e440c460 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -7,11 +7,13 @@ import unittest from omegaconf import OmegaConf -from pytorch3d.implicitron.models.autodecoder import Autodecoder from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( ResNetFeatureExtractor, ) from pytorch3d.implicitron.models.generic_model import GenericModel +from pytorch3d.implicitron.models.global_encoder.global_encoder import ( + SequenceAutodecoder, +) from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( IdrFeatureField, ) @@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase): self.assertIsInstance( gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction ) - self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) + self.assertIsNone(gm.global_encoder) self.assertFalse(hasattr(gm, "implicit_function")) self.assertIsNone(gm.view_pooler) self.assertIsNone(gm.image_feature_extractor) @@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase): ) args.image_feature_extractor_class_type = "ResNetFeatureExtractor" args.implicit_function_class_type = "IdrFeatureField" + args.global_encoder_class_type = "SequenceAutodecoder" idr_args = args.implicit_function_IdrFeatureField_args idr_args.n_harmonic_functions_xyz = 1729 @@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase): ) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729) - self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) + self.assertIsInstance(gm.global_encoder, SequenceAutodecoder) self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor) self.assertFalse(hasattr(gm, "implicit_function")) @@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase): remove_unused_components(instance_args) yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) if DEBUG: + print(DATA_DIR) (DATA_DIR / "overrides.yaml_").write_text(yaml) self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())