mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	more globalencoder followup
Summary: remove n_instances==0 special case, standardise args for GlobalEncoderBase's forward. Reviewed By: shapovalov Differential Revision: D37817340 fbshipit-source-id: 0aac5fbc7c336d09be9d412cffff5712bda27290
This commit is contained in:
		
							parent
							
								
									9d888f1332
								
							
						
					
					
						commit
						02c0254f7f
					
				@ -187,7 +187,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
    global_encoder_SequenceAutodecoder_args:
 | 
			
		||||
      autodecoder_args:
 | 
			
		||||
        encoding_dim: 0
 | 
			
		||||
        n_instances: 0
 | 
			
		||||
        n_instances: 1
 | 
			
		||||
        init_scale: 1.0
 | 
			
		||||
        ignore_input: false
 | 
			
		||||
    raysampler_AdaptiveRaySampler_args:
 | 
			
		||||
 | 
			
		||||
@ -24,15 +24,16 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    encoding_dim: int = 0
 | 
			
		||||
    n_instances: int = 0
 | 
			
		||||
    n_instances: int = 1
 | 
			
		||||
    init_scale: float = 1.0
 | 
			
		||||
    ignore_input: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        if self.n_instances <= 0:
 | 
			
		||||
            # Do not init the codes at all in case we have 0 instances.
 | 
			
		||||
            return
 | 
			
		||||
            raise ValueError(f"Invalid n_instances {self.n_instances}")
 | 
			
		||||
 | 
			
		||||
        self._autodecoder_codes = torch.nn.Embedding(
 | 
			
		||||
            self.n_instances,
 | 
			
		||||
            self.encoding_dim,
 | 
			
		||||
@ -70,13 +71,9 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
        return key_map
 | 
			
		||||
 | 
			
		||||
    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        if self.n_instances <= 0:
 | 
			
		||||
            return None
 | 
			
		||||
        return (self._autodecoder_codes.weight**2).mean()  # pyre-ignore[16]
 | 
			
		||||
 | 
			
		||||
    def get_encoding_dim(self) -> int:
 | 
			
		||||
        if self.n_instances <= 0:
 | 
			
		||||
            return 0
 | 
			
		||||
        return self.encoding_dim
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
 | 
			
		||||
@ -90,9 +87,6 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
            codes: A tensor of shape `(N, self.encoding_dim)` containing the
 | 
			
		||||
                key-specific autodecoder codes.
 | 
			
		||||
        """
 | 
			
		||||
        if self.n_instances == 0:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        if self.ignore_input:
 | 
			
		||||
            x = ["singleton"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -42,7 +42,13 @@ class GlobalEncoderBase(ReplaceableBase):
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def forward(self, **kwargs) -> torch.Tensor:
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        frame_timestamp: Optional[torch.Tensor] = None,
 | 
			
		||||
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Given a set of inputs to encode, generates a tensor containing the encoding.
 | 
			
		||||
 | 
			
		||||
@ -70,9 +76,14 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
 | 
			
		||||
        return self.autodecoder.get_encoding_dim()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        frame_timestamp: Optional[torch.Tensor] = None,
 | 
			
		||||
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
        if sequence_name is None:
 | 
			
		||||
            raise ValueError("sequence_name must be provided.")
 | 
			
		||||
        # run dtype checks and pass sequence_name to self.autodecoder
 | 
			
		||||
        return self.autodecoder(sequence_name)
 | 
			
		||||
 | 
			
		||||
@ -101,7 +112,15 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
 | 
			
		||||
    def get_encoding_dim(self):
 | 
			
		||||
        return self._harmonic_embedding.get_output_dim(1)
 | 
			
		||||
 | 
			
		||||
    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        frame_timestamp: Optional[torch.Tensor] = None,
 | 
			
		||||
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        if frame_timestamp is None:
 | 
			
		||||
            raise ValueError("frame_timestamp must be provided.")
 | 
			
		||||
        if frame_timestamp.shape[-1] != 1:
 | 
			
		||||
            raise ValueError("Frame timestamp's last dimensions should be one.")
 | 
			
		||||
        time = frame_timestamp / self.time_divisor
 | 
			
		||||
 | 
			
		||||
@ -52,7 +52,7 @@ loss_weights:
 | 
			
		||||
global_encoder_SequenceAutodecoder_args:
 | 
			
		||||
  autodecoder_args:
 | 
			
		||||
    encoding_dim: 0
 | 
			
		||||
    n_instances: 0
 | 
			
		||||
    n_instances: 1
 | 
			
		||||
    init_scale: 1.0
 | 
			
		||||
    ignore_input: false
 | 
			
		||||
raysampler_AdaptiveRaySampler_args:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user