followups to D37592429

Summary: Addressing review comments on D37592429 (0dce883241): drop the leading underscore from the nested safe_slice_targets helper in GenericModel, rename calc_squared_encoding_norm to calculate_squared_encoding_norm and annotate it as returning Optional[torch.Tensor], make HarmonicTimeEncoder return None instead of 0.0 so no spurious autodecoder_norm metric is reported, expand the GlobalEncoderBase docstring, and remove a leftover debug print in the tests.

Reviewed By: shapovalov

Differential Revision: D37752367

fbshipit-source-id: 40aa7ee4dc0c5b8b7a84a09d13a3933a9e3afedd
Author: Jeremy Reizenstein
Committed by: Facebook GitHub Bot
Date: 2022-07-13 06:07:02 -07:00
Parent: 55f67b0d18
Commit: b95ec190af

5 changed files with 18 additions and 18 deletions


@@ -368,7 +368,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         # A helper function for selecting n_target first elements from the input
         # where the latter can be None.
-        def _safe_slice_targets(
+        def safe_slice_targets(
             tensor: Optional[Union[torch.Tensor, List[str]]],
         ) -> Optional[Union[torch.Tensor, List[str]]]:
             return None if tensor is None else tensor[:n_targets]
@@ -423,8 +423,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         global_code = None
         if self.global_encoder is not None:
             global_code = self.global_encoder(  # pyre-fixme[29]
-                sequence_name=_safe_slice_targets(sequence_name),
-                frame_timestamp=_safe_slice_targets(frame_timestamp),
+                sequence_name=safe_slice_targets(sequence_name),
+                frame_timestamp=safe_slice_targets(frame_timestamp),
             )
         custom_args["global_code"] = global_code
@@ -469,10 +469,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
                 results=preds,
                 raymarched=rendered,
                 xys=ray_bundle.xys,
-                image_rgb=_safe_slice_targets(image_rgb),
-                depth_map=_safe_slice_targets(depth_map),
-                fg_probability=_safe_slice_targets(fg_probability),
-                mask_crop=_safe_slice_targets(mask_crop),
+                image_rgb=safe_slice_targets(image_rgb),
+                depth_map=safe_slice_targets(depth_map),
+                fg_probability=safe_slice_targets(fg_probability),
+                mask_crop=safe_slice_targets(mask_crop),
             )
         )
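
For context, a minimal self-contained sketch of how the renamed helper behaves; the module-level n_targets below is a hypothetical stand-in for the variable the real nested function closes over inside GenericModel.forward:

    from typing import List, Optional, Union

    import torch

    n_targets = 2  # stand-in for the closure variable in the real code

    def safe_slice_targets(
        tensor: Optional[Union[torch.Tensor, List[str]]],
    ) -> Optional[Union[torch.Tensor, List[str]]]:
        # Keep the first n_targets elements; pass None through unchanged.
        return None if tensor is None else tensor[:n_targets]

    assert safe_slice_targets(None) is None
    assert safe_slice_targets(["seq_a", "seq_b", "seq_c"]) == ["seq_a", "seq_b"]
    assert safe_slice_targets(torch.zeros(5, 3)).shape == (2, 3)

Passing None through unchanged is what lets every call site above hand over possibly-None inputs such as depth_map or fg_probability without guarding each one individually.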


@@ -69,10 +69,10 @@ class Autodecoder(Configurable, torch.nn.Module):
             assert x_id == x_id_
         return key_map
 
-    def calc_squared_encoding_norm(self):
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
         if self.n_instances <= 0:
             return None
-        return (self._autodecoder_codes.weight**2).mean()
+        return (self._autodecoder_codes.weight**2).mean()  # pyre-ignore[16]
 
     def get_encoding_dim(self) -> int:
         if self.n_instances <= 0:
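
As a rough illustration of what the renamed method computes, assuming from the attribute name that _autodecoder_codes is a torch.nn.Embedding holding one latent code per instance:

    import torch

    codes = torch.nn.Embedding(10, 4)  # hypothetical: 10 instances, 4-dim codes

    # Mean of all squared entries of the embedding table: a zero-dimensional
    # tensor that can be added to the training loss as a regularizer.
    penalty = (codes.weight**2).mean()
    assert penalty.ndim == 0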


@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 from pytorch3d.implicitron.tools.config import (
@@ -35,9 +35,10 @@ class GlobalEncoderBase(ReplaceableBase):
         """
         raise NotImplementedError()
 
-    def calc_squared_encoding_norm(self):
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
         """
-        Calculates the squared norm of the encoding.
+        Calculates the squared norm of the encoding to report as the
+        `autodecoder_norm` loss of the model, as a zero dimensional tensor.
         """
         raise NotImplementedError()
 
@@ -75,8 +76,8 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
         # run dtype checks and pass sequence_name to self.autodecoder
         return self.autodecoder(sequence_name)
 
-    def calc_squared_encoding_norm(self):
-        return self.autodecoder.calc_squared_encoding_norm()
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
+        return self.autodecoder.calculate_squared_encoding_norm()
 
 
 @registry.register
@@ -106,5 +107,5 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
         time = frame_timestamp / self.time_divisor
         return self._harmonic_embedding(time)  # pyre-ignore: 29
 
-    def calc_squared_encoding_norm(self):
-        return 0.0
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
+        return None


@@ -126,7 +126,7 @@ class RegularizationMetrics(RegularizationMetricsBase):
         """
         metrics = {}
         if getattr(model, "sequence_autodecoder", None) is not None:
-            ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
+            ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
             if ad_penalty is not None:
                 metrics["autodecoder_norm"] = ad_penalty


@@ -90,6 +90,5 @@ class TestGenericModel(unittest.TestCase):
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
-            print(DATA_DIR)
             (DATA_DIR / "overrides.yaml_").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())