Mirror of https://github.com/facebookresearch/pytorch3d.git
followups to D37592429
Summary: Fixing comments on D37592429 (0dce883241)
Reviewed By: shapovalov
Differential Revision: D37752367
fbshipit-source-id: 40aa7ee4dc0c5b8b7a84a09d13a3933a9e3afedd
commit b95ec190af
parent 55f67b0d18
@@ -368,7 +368,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         # A helper function for selecting n_target first elements from the input
         # where the latter can be None.
-        def _safe_slice_targets(
+        def safe_slice_targets(
             tensor: Optional[Union[torch.Tensor, List[str]]],
         ) -> Optional[Union[torch.Tensor, List[str]]]:
             return None if tensor is None else tensor[:n_targets]
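The rename drops the leading underscore, presumably because a function nested inside `forward` is already private to it. For context, a minimal standalone sketch of the helper's behavior; here `n_targets` is a hypothetical fixed value, whereas in the model it is captured as a closure variable from the enclosing method:

    from typing import List, Optional, Union

    import torch

    n_targets = 2  # hypothetical; the model captures this from the enclosing scope

    def safe_slice_targets(
        tensor: Optional[Union[torch.Tensor, List[str]]],
    ) -> Optional[Union[torch.Tensor, List[str]]]:
        # None passes through unchanged; anything else keeps its first n_targets entries.
        return None if tensor is None else tensor[:n_targets]

    assert safe_slice_targets(None) is None
    assert safe_slice_targets(["seq_a", "seq_b", "seq_c"]) == ["seq_a", "seq_b"]
    assert safe_slice_targets(torch.zeros(4, 3)).shape == (2, 3)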
@@ -423,8 +423,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
         global_code = None
         if self.global_encoder is not None:
             global_code = self.global_encoder( # pyre-fixme[29]
-                sequence_name=_safe_slice_targets(sequence_name),
-                frame_timestamp=_safe_slice_targets(frame_timestamp),
+                sequence_name=safe_slice_targets(sequence_name),
+                frame_timestamp=safe_slice_targets(frame_timestamp),
             )
             custom_args["global_code"] = global_code

@@ -469,10 +469,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
                 results=preds,
                 raymarched=rendered,
                 xys=ray_bundle.xys,
-                image_rgb=_safe_slice_targets(image_rgb),
-                depth_map=_safe_slice_targets(depth_map),
-                fg_probability=_safe_slice_targets(fg_probability),
-                mask_crop=_safe_slice_targets(mask_crop),
+                image_rgb=safe_slice_targets(image_rgb),
+                depth_map=safe_slice_targets(depth_map),
+                fg_probability=safe_slice_targets(fg_probability),
+                mask_crop=safe_slice_targets(mask_crop),
             )
         )

@@ -69,10 +69,10 @@ class Autodecoder(Configurable, torch.nn.Module):
             assert x_id == x_id_
         return key_map

-    def calc_squared_encoding_norm(self):
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
         if self.n_instances <= 0:
             return None
-        return (self._autodecoder_codes.weight**2).mean()
+        return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]

     def get_encoding_dim(self) -> int:
         if self.n_instances <= 0:
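The penalty here is simply the mean of the squared embedding weights. A hedged sketch of the same computation on a bare `torch.nn.Embedding`; the variable names are illustrative stand-ins, not the Implicitron API:

    import torch

    codes = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)  # stand-in for _autodecoder_codes
    penalty = (codes.weight**2).mean()  # zero-dimensional tensor, differentiable
    assert penalty.ndim == 0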
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Union
+from typing import List, Optional, Union

 import torch
 from pytorch3d.implicitron.tools.config import (
@@ -35,9 +35,10 @@ class GlobalEncoderBase(ReplaceableBase):
         """
         raise NotImplementedError()

-    def calc_squared_encoding_norm(self):
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
         """
-        Calculates the squared norm of the encoding.
+        Calculates the squared norm of the encoding to report as the
+        `autodecoder_norm` loss of the model, as a zero dimensional tensor.
         """
         raise NotImplementedError()

@@ -75,8 +76,8 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
         # run dtype checks and pass sequence_name to self.autodecoder
         return self.autodecoder(sequence_name)

-    def calc_squared_encoding_norm(self):
-        return self.autodecoder.calc_squared_encoding_norm()
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
+        return self.autodecoder.calculate_squared_encoding_norm()


 @registry.register
@@ -106,5 +107,5 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
         time = frame_timestamp / self.time_divisor
         return self._harmonic_embedding(time) # pyre-ignore: 29

-    def calc_squared_encoding_norm(self):
-        return 0.0
+    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
+        return None
@@ -126,7 +126,7 @@ class RegularizationMetrics(RegularizationMetricsBase):
         """
         metrics = {}
         if getattr(model, "sequence_autodecoder", None) is not None:
-            ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
+            ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
             if ad_penalty is not None:
                 metrics["autodecoder_norm"] = ad_penalty

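Returning `None` (rather than the `0.0` that `HarmonicTimeEncoder` previously returned) lets the metrics code above omit the `autodecoder_norm` entry entirely instead of logging a meaningless constant. A minimal sketch of that contract, assuming nothing beyond what the hunks above show; `collect_regularization_metrics` is a hypothetical helper, not part of the library:

    from typing import Dict, Optional

    import torch

    def collect_regularization_metrics(
        ad_penalty: Optional[torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # Mirrors the logic above: a None penalty means "no learned encoding",
        # so the autodecoder_norm metric is simply not reported.
        metrics: Dict[str, torch.Tensor] = {}
        if ad_penalty is not None:
            metrics["autodecoder_norm"] = ad_penalty
        return metrics

    assert collect_regularization_metrics(None) == {}
    assert "autodecoder_norm" in collect_regularization_metrics(torch.tensor(0.5))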
@@ -90,6 +90,5 @@ class TestGenericModel(unittest.TestCase):
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
-            print(DATA_DIR)
             (DATA_DIR / "overrides.yaml_").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())