followups to D37592429

Summary: Fixing comments on D37592429 (0dce883241)

Reviewed By: shapovalov

Differential Revision: D37752367

fbshipit-source-id: 40aa7ee4dc0c5b8b7a84a09d13a3933a9e3afedd
This commit is contained in:
Jeremy Reizenstein 2022-07-13 06:07:02 -07:00 committed by Facebook GitHub Bot
parent 55f67b0d18
commit b95ec190af
5 changed files with 18 additions and 18 deletions

View File

@ -368,7 +368,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# A helper function for selecting n_target first elements from the input
# where the latter can be None.
def _safe_slice_targets(
def safe_slice_targets(
tensor: Optional[Union[torch.Tensor, List[str]]],
) -> Optional[Union[torch.Tensor, List[str]]]:
return None if tensor is None else tensor[:n_targets]
@ -423,8 +423,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
global_code = None
if self.global_encoder is not None:
global_code = self.global_encoder( # pyre-fixme[29]
sequence_name=_safe_slice_targets(sequence_name),
frame_timestamp=_safe_slice_targets(frame_timestamp),
sequence_name=safe_slice_targets(sequence_name),
frame_timestamp=safe_slice_targets(frame_timestamp),
)
custom_args["global_code"] = global_code
@ -469,10 +469,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
results=preds,
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=_safe_slice_targets(image_rgb),
depth_map=_safe_slice_targets(depth_map),
fg_probability=_safe_slice_targets(fg_probability),
mask_crop=_safe_slice_targets(mask_crop),
image_rgb=safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability),
mask_crop=safe_slice_targets(mask_crop),
)
)

View File

@ -69,10 +69,10 @@ class Autodecoder(Configurable, torch.nn.Module):
assert x_id == x_id_
return key_map
def calc_squared_encoding_norm(self):
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
if self.n_instances <= 0:
return None
return (self._autodecoder_codes.weight**2).mean()
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
def get_encoding_dim(self) -> int:
if self.n_instances <= 0:

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Union
from typing import List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import (
@ -35,9 +35,10 @@ class GlobalEncoderBase(ReplaceableBase):
"""
raise NotImplementedError()
def calc_squared_encoding_norm(self):
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
"""
Calculates the squared norm of the encoding.
Calculates the squared norm of the encoding to report as the
`autodecoder_norm` loss of the model, as a zero dimensional tensor.
"""
raise NotImplementedError()
@ -75,8 +76,8 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
# run dtype checks and pass sequence_name to self.autodecoder
return self.autodecoder(sequence_name)
def calc_squared_encoding_norm(self):
return self.autodecoder.calc_squared_encoding_norm()
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
return self.autodecoder.calculate_squared_encoding_norm()
@registry.register
@ -106,5 +107,5 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
time = frame_timestamp / self.time_divisor
return self._harmonic_embedding(time) # pyre-ignore: 29
def calc_squared_encoding_norm(self):
return 0.0
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
return None

View File

@ -126,7 +126,7 @@ class RegularizationMetrics(RegularizationMetricsBase):
"""
metrics = {}
if getattr(model, "sequence_autodecoder", None) is not None:
ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
if ad_penalty is not None:
metrics["autodecoder_norm"] = ad_penalty

View File

@ -90,6 +90,5 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
print(DATA_DIR)
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())