Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
more globalencoder followup
Summary: remove n_instances==0 special case, standardise args for GlobalEncoderBase's forward.

Reviewed By: shapovalov

Differential Revision: D37817340

fbshipit-source-id: 0aac5fbc7c336d09be9d412cffff5712bda27290
This commit is contained in:
parent 9d888f1332
commit 02c0254f7f
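The practical effect of the standardised signature (shown in the hunks below) is that every GlobalEncoderBase subclass accepts the same keyword-only arguments, so call sites no longer need to special-case the concrete encoder. A minimal sketch of such a caller — a hypothetical helper, not part of the commit:

    def encode_frames(encoder, frame_timestamp, sequence_name):
        # Works for any GlobalEncoderBase subclass: SequenceAutodecoder reads
        # sequence_name, HarmonicTimeEncoder reads frame_timestamp, and each
        # raises ValueError if the argument it needs was not provided.
        return encoder(
            frame_timestamp=frame_timestamp,
            sequence_name=sequence_name,
        )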
@@ -187,7 +187,7 @@ model_factory_ImplicitronModelFactory_args:
     global_encoder_SequenceAutodecoder_args:
       autodecoder_args:
         encoding_dim: 0
-        n_instances: 0
+        n_instances: 1
         init_scale: 1.0
         ignore_input: false
     raysampler_AdaptiveRaySampler_args:
@@ -24,15 +24,16 @@ class Autodecoder(Configurable, torch.nn.Module):
     """

     encoding_dim: int = 0
-    n_instances: int = 0
+    n_instances: int = 1
     init_scale: float = 1.0
     ignore_input: bool = False

     def __post_init__(self):
         super().__init__()

         if self.n_instances <= 0:
-            # Do not init the codes at all in case we have 0 instances.
-            return
+            raise ValueError(f"Invalid n_instances {self.n_instances}")

         self._autodecoder_codes = torch.nn.Embedding(
             self.n_instances,
             self.encoding_dim,
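With the special case gone, an Autodecoder always allocates its embedding table, and a non-positive n_instances now fails fast at construction instead of silently skipping code initialisation. A sketch of the new behaviour, assuming the module path below and that the Configurable class is expanded before direct use:

    from pytorch3d.implicitron.tools.config import expand_args_fields
    from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder

    expand_args_fields(Autodecoder)  # process the Configurable into a dataclass

    codec = Autodecoder(encoding_dim=16, n_instances=10)
    codes = codec(["seq_a", "seq_b"])  # (2, 16) codes; never None any more

    try:
        Autodecoder(encoding_dim=16, n_instances=0)
    except ValueError as err:
        print(err)  # Invalid n_instances 0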
@@ -70,13 +71,9 @@ class Autodecoder(Configurable, torch.nn.Module):
         return key_map

     def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
-        if self.n_instances <= 0:
-            return None
         return (self._autodecoder_codes.weight**2).mean()  # pyre-ignore[16]

     def get_encoding_dim(self) -> int:
-        if self.n_instances <= 0:
-            return 0
         return self.encoding_dim

     def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
@@ -90,9 +87,6 @@ class Autodecoder(Configurable, torch.nn.Module):
             codes: A tensor of shape `(N, self.encoding_dim)` containing the
                 key-specific autodecoder codes.
         """
-        if self.n_instances == 0:
-            return None
-
         if self.ignore_input:
             x = ["singleton"]
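Since forward can no longer return None, the one remaining degenerate setup is a single shared code: ignore_input=True maps every query to the "singleton" key, which is why the default n_instances had to become 1. A short sketch, under the same assumed import path as above:

    import torch
    from pytorch3d.implicitron.tools.config import expand_args_fields
    from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder

    expand_args_fields(Autodecoder)

    single = Autodecoder(encoding_dim=8, n_instances=1, ignore_input=True)
    a = single(["seq_a"])
    b = single(["seq_b"])  # input is ignored; both hit the "singleton" code
    assert torch.equal(a, b)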
@@ -42,7 +42,13 @@ class GlobalEncoderBase(ReplaceableBase):
         """
         raise NotImplementedError()

-    def forward(self, **kwargs) -> torch.Tensor:
+    def forward(
+        self,
+        *,
+        frame_timestamp: Optional[torch.Tensor] = None,
+        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         """
         Given a set of inputs to encode, generates a tensor containing the encoding.
@@ -70,9 +76,14 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
         return self.autodecoder.get_encoding_dim()

     def forward(
-        self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
+        self,
+        *,
+        frame_timestamp: Optional[torch.Tensor] = None,
+        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
+        **kwargs,
     ) -> torch.Tensor:
+        if sequence_name is None:
+            raise ValueError("sequence_name must be provided.")
         # run dtype checks and pass sequence_name to self.autodecoder
         return self.autodecoder(sequence_name)
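Note the bare `*`: sequence_name is now keyword-only, so any call site that passed it positionally has to be updated, and omitting it raises immediately rather than failing deeper in the stack. A hypothetical before/after, assuming encoder is an already-constructed SequenceAutodecoder:

    # before: encoder(sequence_names) worked positionally
    codes = encoder(sequence_name=["seq_a", "seq_b"])  # after: keyword required

    try:
        encoder()  # forgetting the argument now fails with a clear message
    except ValueError as err:
        print(err)  # sequence_name must be provided.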
@@ -101,7 +112,15 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
     def get_encoding_dim(self):
         return self._harmonic_embedding.get_output_dim(1)

-    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
+    def forward(
+        self,
+        *,
+        frame_timestamp: Optional[torch.Tensor] = None,
+        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if frame_timestamp is None:
+            raise ValueError("frame_timestamp must be provided.")
         if frame_timestamp.shape[-1] != 1:
             raise ValueError("Frame timestamp's last dimensions should be one.")
         time = frame_timestamp / self.time_divisor
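The same pattern applies to the time encoder; the pre-existing shape check still requires a trailing singleton dimension. A sketch, assuming time_encoder is a constructed HarmonicTimeEncoder:

    import torch

    timestamps = torch.tensor([[0.5], [1.0], [2.5]])  # (N, 1): last dim must be 1
    encoding = time_encoder(frame_timestamp=timestamps)

    try:
        time_encoder(frame_timestamp=torch.tensor([0.5, 1.0]))  # shape (2,), not (N, 1)
    except ValueError as err:
        print(err)  # Frame timestamp's last dimensions should be one.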
@@ -52,7 +52,7 @@ loss_weights:
 global_encoder_SequenceAutodecoder_args:
   autodecoder_args:
     encoding_dim: 0
-    n_instances: 0
+    n_instances: 1
     init_scale: 1.0
     ignore_input: false
 raysampler_AdaptiveRaySampler_args:
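These YAML defaults are generated from the dataclass fields, so they had to be regenerated to track the new n_instances default. A quick way to confirm the correspondence, assuming get_default_args from implicitron's config tooling behaves as sketched:

    from pytorch3d.implicitron.tools.config import get_default_args
    from pytorch3d.implicitron.models.global_encoder.autodecoder import Autodecoder

    defaults = get_default_args(Autodecoder)  # DictConfig of the dataclass defaults
    assert defaults.n_instances == 1  # matches the regenerated YAML above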