From 02c0254f7f9815b2354183b0ba0e479810824a58 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 5 Aug 2022 03:33:30 -0700 Subject: [PATCH] more globalencoder followup Summary: remove n_instances==0 special case, standardise args for GlobalEncoderBase's forward. Reviewed By: shapovalov Differential Revision: D37817340 fbshipit-source-id: 0aac5fbc7c336d09be9d412cffff5712bda27290 --- .../implicitron_trainer/tests/experiment.yaml | 2 +- .../models/global_encoder/autodecoder.py | 14 +++------- .../models/global_encoder/global_encoder.py | 27 ++++++++++++++++--- tests/implicitron/data/overrides.yaml | 2 +- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index f4bf12e9..e1ddc7ca 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -187,7 +187,7 @@ model_factory_ImplicitronModelFactory_args: global_encoder_SequenceAutodecoder_args: autodecoder_args: encoding_dim: 0 - n_instances: 0 + n_instances: 1 init_scale: 1.0 ignore_input: false raysampler_AdaptiveRaySampler_args: diff --git a/pytorch3d/implicitron/models/global_encoder/autodecoder.py b/pytorch3d/implicitron/models/global_encoder/autodecoder.py index 5a072d54..b03d5588 100644 --- a/pytorch3d/implicitron/models/global_encoder/autodecoder.py +++ b/pytorch3d/implicitron/models/global_encoder/autodecoder.py @@ -24,15 +24,16 @@ class Autodecoder(Configurable, torch.nn.Module): """ encoding_dim: int = 0 - n_instances: int = 0 + n_instances: int = 1 init_scale: float = 1.0 ignore_input: bool = False def __post_init__(self): super().__init__() + if self.n_instances <= 0: - # Do not init the codes at all in case we have 0 instances. - return + raise ValueError(f"Invalid n_instances {self.n_instances}") + self._autodecoder_codes = torch.nn.Embedding( self.n_instances, self.encoding_dim, @@ -70,13 +71,9 @@ class Autodecoder(Configurable, torch.nn.Module): return key_map def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: - if self.n_instances <= 0: - return None return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16] def get_encoding_dim(self) -> int: - if self.n_instances <= 0: - return 0 return self.encoding_dim def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]: @@ -90,9 +87,6 @@ class Autodecoder(Configurable, torch.nn.Module): codes: A tensor of shape `(N, self.encoding_dim)` containing the key-specific autodecoder codes. """ - if self.n_instances == 0: - return None - if self.ignore_input: x = ["singleton"] diff --git a/pytorch3d/implicitron/models/global_encoder/global_encoder.py b/pytorch3d/implicitron/models/global_encoder/global_encoder.py index 3b919b8c..641433ad 100644 --- a/pytorch3d/implicitron/models/global_encoder/global_encoder.py +++ b/pytorch3d/implicitron/models/global_encoder/global_encoder.py @@ -42,7 +42,13 @@ class GlobalEncoderBase(ReplaceableBase): """ raise NotImplementedError() - def forward(self, **kwargs) -> torch.Tensor: + def forward( + self, + *, + frame_timestamp: Optional[torch.Tensor] = None, + sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None, + **kwargs, + ) -> torch.Tensor: """ Given a set of inputs to encode, generates a tensor containing the encoding. 
@@ -70,9 +76,14 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1 return self.autodecoder.get_encoding_dim() def forward( - self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs + self, + *, + frame_timestamp: Optional[torch.Tensor] = None, + sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None, + **kwargs, ) -> torch.Tensor: - + if sequence_name is None: + raise ValueError("sequence_name must be provided.") # run dtype checks and pass sequence_name to self.autodecoder return self.autodecoder(sequence_name) @@ -101,7 +112,15 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module): def get_encoding_dim(self): return self._harmonic_embedding.get_output_dim(1) - def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + *, + frame_timestamp: Optional[torch.Tensor] = None, + sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None, + **kwargs, + ) -> torch.Tensor: + if frame_timestamp is None: + raise ValueError("frame_timestamp must be provided.") if frame_timestamp.shape[-1] != 1: raise ValueError("Frame timestamp's last dimensions should be one.") time = frame_timestamp / self.time_divisor diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index 7bbd5df0..1414a5e8 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -52,7 +52,7 @@ loss_weights: global_encoder_SequenceAutodecoder_args: autodecoder_args: encoding_dim: 0 - n_instances: 0 + n_instances: 1 init_scale: 1.0 ignore_input: false raysampler_AdaptiveRaySampler_args:
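Usage sketch (illustration only, not part of the patch): after this change, Autodecoder rejects a non-positive n_instances at construction time instead of silently disabling the codes, and its accessors no longer carry an n_instances == 0 branch. The construction pattern below assumes the usual implicitron convention of expanding a Configurable before instantiating it; the key names and dimensions are invented for the example.

```python
# Illustrative sketch only, not part of this patch. Assumes the implicitron
# expand_args_fields convention; "seq_a"/"seq_b" and the sizes are invented.
from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(Autodecoder)  # Configurables are expanded before use

# n_instances now defaults to 1 and must be positive.
codec = Autodecoder(encoding_dim=8, n_instances=2)
codes = codec(["seq_a", "seq_b"])               # (2, 8): one learnable code per key
norm = codec.calculate_squared_encoding_norm()  # always a tensor for a valid config

# The old n_instances == 0 special case is gone; zero is rejected up front.
try:
    Autodecoder(encoding_dim=8, n_instances=0)
except ValueError as err:
    print(err)  # Invalid n_instances 0
```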
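The standardised GlobalEncoderBase.forward signature means a caller can always pass both frame_timestamp and sequence_name as keyword arguments and let each encoder pick out the one it needs. Below is a sketch of that calling convention; it assumes the implicitron get_default_args construction pattern, and the sequence names, batch size and timestamps are invented.

```python
# Illustrative sketch only, not part of this patch. get_default_args and the
# construction pattern follow the implicitron config conventions; the sequence
# names and timestamps are invented.
import torch

from pytorch3d.implicitron.models.global_encoder.global_encoder import (
    HarmonicTimeEncoder,
    SequenceAutodecoder,
)
from pytorch3d.implicitron.tools.config import get_default_args

seq_args = get_default_args(SequenceAutodecoder)
seq_args.autodecoder_args.encoding_dim = 8
seq_args.autodecoder_args.n_instances = 2
seq_encoder = SequenceAutodecoder(**seq_args)

time_encoder = HarmonicTimeEncoder(**get_default_args(HarmonicTimeEncoder))

# Every GlobalEncoderBase implementation now takes the same keyword-only
# arguments; each one validates the argument it needs and ignores the rest.
batch = {
    "sequence_name": ["seq_a", "seq_b"],
    "frame_timestamp": torch.zeros(2, 1),
}
seq_codes = seq_encoder(**batch)    # uses sequence_name
time_codes = time_encoder(**batch)  # uses frame_timestamp
```

A side effect of the keyword-only arguments defaulting to None is that an encoder missing the input it requires now fails with the explicit ValueError added in this patch, rather than a positional-argument TypeError.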