mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
followups to D37592429
Summary: Fixing comments on D37592429 (0dce883241
)
Reviewed By: shapovalov
Differential Revision: D37752367
fbshipit-source-id: 40aa7ee4dc0c5b8b7a84a09d13a3933a9e3afedd
This commit is contained in:
parent
55f67b0d18
commit
b95ec190af
@ -368,7 +368,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
|
|
||||||
# A helper function for selecting n_target first elements from the input
|
# A helper function for selecting n_target first elements from the input
|
||||||
# where the latter can be None.
|
# where the latter can be None.
|
||||||
def _safe_slice_targets(
|
def safe_slice_targets(
|
||||||
tensor: Optional[Union[torch.Tensor, List[str]]],
|
tensor: Optional[Union[torch.Tensor, List[str]]],
|
||||||
) -> Optional[Union[torch.Tensor, List[str]]]:
|
) -> Optional[Union[torch.Tensor, List[str]]]:
|
||||||
return None if tensor is None else tensor[:n_targets]
|
return None if tensor is None else tensor[:n_targets]
|
||||||
@ -423,8 +423,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
global_code = None
|
global_code = None
|
||||||
if self.global_encoder is not None:
|
if self.global_encoder is not None:
|
||||||
global_code = self.global_encoder( # pyre-fixme[29]
|
global_code = self.global_encoder( # pyre-fixme[29]
|
||||||
sequence_name=_safe_slice_targets(sequence_name),
|
sequence_name=safe_slice_targets(sequence_name),
|
||||||
frame_timestamp=_safe_slice_targets(frame_timestamp),
|
frame_timestamp=safe_slice_targets(frame_timestamp),
|
||||||
)
|
)
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
@ -469,10 +469,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
results=preds,
|
results=preds,
|
||||||
raymarched=rendered,
|
raymarched=rendered,
|
||||||
xys=ray_bundle.xys,
|
xys=ray_bundle.xys,
|
||||||
image_rgb=_safe_slice_targets(image_rgb),
|
image_rgb=safe_slice_targets(image_rgb),
|
||||||
depth_map=_safe_slice_targets(depth_map),
|
depth_map=safe_slice_targets(depth_map),
|
||||||
fg_probability=_safe_slice_targets(fg_probability),
|
fg_probability=safe_slice_targets(fg_probability),
|
||||||
mask_crop=_safe_slice_targets(mask_crop),
|
mask_crop=safe_slice_targets(mask_crop),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -69,10 +69,10 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
assert x_id == x_id_
|
assert x_id == x_id_
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
return None
|
return None
|
||||||
return (self._autodecoder_codes.weight**2).mean()
|
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
|
||||||
|
|
||||||
def get_encoding_dim(self) -> int:
|
def get_encoding_dim(self) -> int:
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -35,9 +35,10 @@ class GlobalEncoderBase(ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculates the squared norm of the encoding.
|
Calculates the squared norm of the encoding to report as the
|
||||||
|
`autodecoder_norm` loss of the model, as a zero dimensional tensor.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -75,8 +76,8 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
|
|||||||
# run dtype checks and pass sequence_name to self.autodecoder
|
# run dtype checks and pass sequence_name to self.autodecoder
|
||||||
return self.autodecoder(sequence_name)
|
return self.autodecoder(sequence_name)
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
return self.autodecoder.calc_squared_encoding_norm()
|
return self.autodecoder.calculate_squared_encoding_norm()
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
@ -106,5 +107,5 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
|||||||
time = frame_timestamp / self.time_divisor
|
time = frame_timestamp / self.time_divisor
|
||||||
return self._harmonic_embedding(time) # pyre-ignore: 29
|
return self._harmonic_embedding(time) # pyre-ignore: 29
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
return 0.0
|
return None
|
||||||
|
@ -126,7 +126,7 @@ class RegularizationMetrics(RegularizationMetricsBase):
|
|||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
if getattr(model, "sequence_autodecoder", None) is not None:
|
if getattr(model, "sequence_autodecoder", None) is not None:
|
||||||
ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
|
ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
|
||||||
if ad_penalty is not None:
|
if ad_penalty is not None:
|
||||||
metrics["autodecoder_norm"] = ad_penalty
|
metrics["autodecoder_norm"] = ad_penalty
|
||||||
|
|
||||||
|
@ -90,6 +90,5 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
remove_unused_components(instance_args)
|
remove_unused_components(instance_args)
|
||||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
print(DATA_DIR)
|
|
||||||
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user