From b95ec190af7794f40562dff1082b59c4b4590590 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 13 Jul 2022 06:07:02 -0700 Subject: [PATCH] followups to D37592429 Summary: Fixing comments on D37592429 (https://github.com/facebookresearch/pytorch3d/commit/0dce883241ae638b9fa824f34fca9590d5f0782c) Reviewed By: shapovalov Differential Revision: D37752367 fbshipit-source-id: 40aa7ee4dc0c5b8b7a84a09d13a3933a9e3afedd --- pytorch3d/implicitron/models/generic_model.py | 14 +++++++------- .../models/global_encoder/autodecoder.py | 4 ++-- .../models/global_encoder/global_encoder.py | 15 ++++++++------- pytorch3d/implicitron/models/metrics.py | 2 +- tests/implicitron/test_config_use.py | 1 - 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 78718e90..9952a67d 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -368,7 +368,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # A helper function for selecting n_target first elements from the input # where the latter can be None. - def _safe_slice_targets( + def safe_slice_targets( tensor: Optional[Union[torch.Tensor, List[str]]], ) -> Optional[Union[torch.Tensor, List[str]]]: return None if tensor is None else tensor[:n_targets] @@ -423,8 +423,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 global_code = None if self.global_encoder is not None: global_code = self.global_encoder( # pyre-fixme[29] - sequence_name=_safe_slice_targets(sequence_name), - frame_timestamp=_safe_slice_targets(frame_timestamp), + sequence_name=safe_slice_targets(sequence_name), + frame_timestamp=safe_slice_targets(frame_timestamp), ) custom_args["global_code"] = global_code @@ -469,10 +469,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 results=preds, raymarched=rendered, xys=ray_bundle.xys, - image_rgb=_safe_slice_targets(image_rgb), - depth_map=_safe_slice_targets(depth_map), - fg_probability=_safe_slice_targets(fg_probability), - mask_crop=_safe_slice_targets(mask_crop), + image_rgb=safe_slice_targets(image_rgb), + depth_map=safe_slice_targets(depth_map), + fg_probability=safe_slice_targets(fg_probability), + mask_crop=safe_slice_targets(mask_crop), ) ) diff --git a/pytorch3d/implicitron/models/global_encoder/autodecoder.py b/pytorch3d/implicitron/models/global_encoder/autodecoder.py index 3089ac68..e93e15e4 100644 --- a/pytorch3d/implicitron/models/global_encoder/autodecoder.py +++ b/pytorch3d/implicitron/models/global_encoder/autodecoder.py @@ -69,10 +69,10 @@ class Autodecoder(Configurable, torch.nn.Module): assert x_id == x_id_ return key_map - def calc_squared_encoding_norm(self): + def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: if self.n_instances <= 0: return None - return (self._autodecoder_codes.weight**2).mean() + return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16] def get_encoding_dim(self) -> int: if self.n_instances <= 0: diff --git a/pytorch3d/implicitron/models/global_encoder/global_encoder.py b/pytorch3d/implicitron/models/global_encoder/global_encoder.py index 37a6d7d6..3b919b8c 100644 --- a/pytorch3d/implicitron/models/global_encoder/global_encoder.py +++ b/pytorch3d/implicitron/models/global_encoder/global_encoder.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Union +from typing import List, Optional, Union import torch from pytorch3d.implicitron.tools.config import ( @@ -35,9 +35,10 @@ class GlobalEncoderBase(ReplaceableBase): """ raise NotImplementedError() - def calc_squared_encoding_norm(self): + def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: """ - Calculates the squared norm of the encoding. + Calculates the squared norm of the encoding to report as the + `autodecoder_norm` loss of the model, as a zero dimensional tensor. """ raise NotImplementedError() @@ -75,8 +76,8 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1 # run dtype checks and pass sequence_name to self.autodecoder return self.autodecoder(sequence_name) - def calc_squared_encoding_norm(self): - return self.autodecoder.calc_squared_encoding_norm() + def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: + return self.autodecoder.calculate_squared_encoding_norm() @registry.register @@ -106,5 +107,5 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module): time = frame_timestamp / self.time_divisor return self._harmonic_embedding(time) # pyre-ignore: 29 - def calc_squared_encoding_norm(self): - return 0.0 + def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: + return None diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py index 168c8bde..63a19724 100644 --- a/pytorch3d/implicitron/models/metrics.py +++ b/pytorch3d/implicitron/models/metrics.py @@ -126,7 +126,7 @@ class RegularizationMetrics(RegularizationMetricsBase): """ metrics = {} if getattr(model, "sequence_autodecoder", None) is not None: - ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm() + ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm() if ad_penalty is not None: metrics["autodecoder_norm"] = ad_penalty diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index e440c460..e3b63147 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -90,6 +90,5 @@ class TestGenericModel(unittest.TestCase): remove_unused_components(instance_args) yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) if DEBUG: - print(DATA_DIR) (DATA_DIR / "overrides.yaml_").write_text(yaml) self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())