More GlobalEncoder follow-up

Summary: remove n_instances==0 special case, standardise args for GlobalEncoderBase's forward.

Reviewed By: shapovalov

Differential Revision: D37817340

fbshipit-source-id: 0aac5fbc7c336d09be9d412cffff5712bda27290
This commit is contained in:
Jeremy Reizenstein 2022-08-05 03:33:30 -07:00 committed by Facebook GitHub Bot
parent 9d888f1332
commit 02c0254f7f
4 changed files with 29 additions and 16 deletions

View File

@ -187,7 +187,7 @@ model_factory_ImplicitronModelFactory_args:
global_encoder_SequenceAutodecoder_args: global_encoder_SequenceAutodecoder_args:
autodecoder_args: autodecoder_args:
encoding_dim: 0 encoding_dim: 0
n_instances: 0 n_instances: 1
init_scale: 1.0 init_scale: 1.0
ignore_input: false ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:

View File

@ -24,15 +24,16 @@ class Autodecoder(Configurable, torch.nn.Module):
""" """
encoding_dim: int = 0 encoding_dim: int = 0
n_instances: int = 0 n_instances: int = 1
init_scale: float = 1.0 init_scale: float = 1.0
ignore_input: bool = False ignore_input: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
if self.n_instances <= 0: if self.n_instances <= 0:
# Do not init the codes at all in case we have 0 instances. raise ValueError(f"Invalid n_instances {self.n_instances}")
return
self._autodecoder_codes = torch.nn.Embedding( self._autodecoder_codes = torch.nn.Embedding(
self.n_instances, self.n_instances,
self.encoding_dim, self.encoding_dim,
@ -70,13 +71,9 @@ class Autodecoder(Configurable, torch.nn.Module):
return key_map return key_map
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
if self.n_instances <= 0:
return None
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16] return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
def get_encoding_dim(self) -> int: def get_encoding_dim(self) -> int:
if self.n_instances <= 0:
return 0
return self.encoding_dim return self.encoding_dim
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]: def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
@ -90,9 +87,6 @@ class Autodecoder(Configurable, torch.nn.Module):
codes: A tensor of shape `(N, self.encoding_dim)` containing the codes: A tensor of shape `(N, self.encoding_dim)` containing the
key-specific autodecoder codes. key-specific autodecoder codes.
""" """
if self.n_instances == 0:
return None
if self.ignore_input: if self.ignore_input:
x = ["singleton"] x = ["singleton"]

View File

@ -42,7 +42,13 @@ class GlobalEncoderBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def forward(self, **kwargs) -> torch.Tensor: def forward(
self,
*,
frame_timestamp: Optional[torch.Tensor] = None,
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
**kwargs,
) -> torch.Tensor:
""" """
Given a set of inputs to encode, generates a tensor containing the encoding. Given a set of inputs to encode, generates a tensor containing the encoding.
@ -70,9 +76,14 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
return self.autodecoder.get_encoding_dim() return self.autodecoder.get_encoding_dim()
def forward( def forward(
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs self,
*,
frame_timestamp: Optional[torch.Tensor] = None,
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
if sequence_name is None:
raise ValueError("sequence_name must be provided.")
# run dtype checks and pass sequence_name to self.autodecoder # run dtype checks and pass sequence_name to self.autodecoder
return self.autodecoder(sequence_name) return self.autodecoder(sequence_name)
@ -101,7 +112,15 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
def get_encoding_dim(self): def get_encoding_dim(self):
return self._harmonic_embedding.get_output_dim(1) return self._harmonic_embedding.get_output_dim(1)
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor: def forward(
self,
*,
frame_timestamp: Optional[torch.Tensor] = None,
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
**kwargs,
) -> torch.Tensor:
if frame_timestamp is None:
raise ValueError("frame_timestamp must be provided.")
if frame_timestamp.shape[-1] != 1: if frame_timestamp.shape[-1] != 1:
raise ValueError("Frame timestamp's last dimensions should be one.") raise ValueError("Frame timestamp's last dimensions should be one.")
time = frame_timestamp / self.time_divisor time = frame_timestamp / self.time_divisor

View File

@ -52,7 +52,7 @@ loss_weights:
global_encoder_SequenceAutodecoder_args: global_encoder_SequenceAutodecoder_args:
autodecoder_args: autodecoder_args:
encoding_dim: 0 encoding_dim: 0
n_instances: 0 n_instances: 1
init_scale: 1.0 init_scale: 1.0
ignore_input: false ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args: