mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-27 08:46:00 +08:00
Summary: See D52890934 Reviewed By: malfet, r-barnes Differential Revision: D66245100 fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
130 lines
3.9 KiB
Python
130 lines
3.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
# pyre-unsafe
|
|
|
|
from typing import List, Optional, Union
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.tools.config import (
|
|
registry,
|
|
ReplaceableBase,
|
|
run_auto_creation,
|
|
)
|
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
|
|
|
from .autodecoder import Autodecoder
|
|
|
|
|
|
class GlobalEncoderBase(ReplaceableBase):
    """
    Base class for encoders of global, frame-specific quantities.

    Typical implementations are the harmonic encoding of a frame's timestamp
    (`HarmonicTimeEncoder`) and the autodecoder encoding of the sequence a
    frame belongs to (`SequenceAutodecoder`).

    Subclasses are expected to override `get_encoding_dim`,
    `calculate_squared_encoding_norm` and `forward`; the versions defined
    here only raise `NotImplementedError`.
    """

    def get_encoding_dim(self):
        """
        Report the size of the encoding produced by `forward`.
        """
        raise NotImplementedError

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Compute the squared norm of the encoding, reported by the model as
        its `autodecoder_norm` loss (a zero-dimensional tensor).
        """
        raise NotImplementedError

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Produce the global encoding for the supplied inputs.

        Args:
            frame_timestamp: Optional timestamps of the frames.
            sequence_name: Optional sequence identifiers, given either as a
                LongTensor or as a list of strings.
            **kwargs: Extra inputs that subclasses may consume.

        Returns:
            encoding: Tensor holding the global encoding.
        """
        raise NotImplementedError
|
|
|
|
|
|
# TODO: probabilistic embeddings?
|
|
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
    """
    Global encoder that encodes the identifier of each frame's sequence
    using the `autodecoder` member.
    """

    # pyre-fixme[13]: Attribute `autodecoder` is never initialized.
    autodecoder: Autodecoder

    def __post_init__(self):
        # Presumably builds the `autodecoder` member from its config via
        # implicitron's auto-creation machinery; confirm in tools/config.
        run_auto_creation(self)

    def get_encoding_dim(self):
        # The output of this encoder is exactly the autodecoder's output.
        decoder = self.autodecoder
        return decoder.get_encoding_dim()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        if sequence_name is not None:
            # dtype checking of sequence_name is left to self.autodecoder.
            return self.autodecoder(sequence_name)
        raise ValueError("sequence_name must be provided.")

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # Delegate entirely to the wrapped autodecoder.
        decoder = self.autodecoder
        return decoder.calculate_squared_encoding_norm()
|
|
|
|
|
|
@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Members:
        n_harmonic_functions: Forwarded to `HarmonicEmbedding` as its
            `n_harmonic_functions` argument.
        append_input: Forwarded to `HarmonicEmbedding` as its `append_input`
            argument (presumably controls whether the rescaled timestamp is
            concatenated to the embedding; confirm in HarmonicEmbedding).
        time_divisor: Timestamps are divided by this value before embedding.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        # Each frame contributes a single timestamp value, hence input dim 1.
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Embed the (rescaled) frame timestamps harmonically.

        Args:
            frame_timestamp: Tensor of timestamps whose last dimension is 1.
            sequence_name: Unused by this encoder.
            **kwargs: Unused by this encoder.

        Returns:
            The harmonic embedding of `frame_timestamp / time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is missing, is zero-dimensional,
                or its last dimension is not 1.
        """
        if frame_timestamp is None:
            raise ValueError("frame_timestamp must be provided.")
        # A 0-dim tensor has no last dimension; reject it with the same
        # ValueError rather than letting `shape[-1]` raise an IndexError.
        if frame_timestamp.dim() == 0 or frame_timestamp.shape[-1] != 1:
            raise ValueError("Frame timestamp's last dimensions should be one.")
        time = frame_timestamp / self.time_divisor
        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._harmonic_embedding(time)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # This encoder has no autodecoder codes, so there is no norm to report.
        return None
|