mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
bg_color for lstm renderer
Summary: Allow specifying a color for non-opaque pixels in LSTMRenderer. Reviewed By: davnov134 Differential Revision: D37172537 fbshipit-source-id: 6039726678bb7947f7d8cd04035b5023b2d5398c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
65f667fd2e
commit
cba26506b6
@@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
"""
|
||||
Implements the learnable LSTM raymarching function from SRN [1].
|
||||
This requires there to be one implicit function, and it is expected to be
|
||||
like SRNImplicitFunction or SRNHyperNetImplicitFunction.
|
||||
|
||||
Settings:
|
||||
num_raymarch_steps: The number of LSTM raymarching steps.
|
||||
@@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
hidden_size: The dimensionality of the LSTM's hidden state.
|
||||
n_feature_channels: The number of feature channels returned by the
|
||||
implicit_function evaluated at each raymarching step.
|
||||
bg_color: If supplied, used as the background color. Otherwise the pixel
|
||||
generator is used everywhere. This has to have length either 1
|
||||
(for a constant value for all output channels) or equal to the number
|
||||
of output channels (which is `out_features` on the pixel generator,
|
||||
typically 3.)
|
||||
verbose: If `True`, logs raymarching debug info.
|
||||
|
||||
References:
|
||||
@@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
init_depth_noise_std: float = 5e-4
|
||||
hidden_size: int = 16
|
||||
n_feature_channels: int = 256
|
||||
bg_color: Optional[List[float]] = None
|
||||
verbose: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -147,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
if self.bg_color is not None:
|
||||
background = features.new_tensor(self.bg_color)
|
||||
features = torch.lerp(background, features, mask)
|
||||
|
||||
return RendererOutput(
|
||||
features=features[..., 0, :],
|
||||
depths=depth,
|
||||
|
||||
Reference in New Issue
Block a user