bg_color for lstm renderer

Summary: Allow specifying a color for non-opaque pixels in LSTMRenderer.

Reviewed By: davnov134

Differential Revision: D37172537

fbshipit-source-id: 6039726678bb7947f7d8cd04035b5023b2d5398c
This commit is contained in:
Jeremy Reizenstein 2022-06-20 13:46:35 -07:00 committed by Facebook GitHub Bot
parent 65f667fd2e
commit cba26506b6
4 changed files with 36 additions and 39 deletions

View File

@ -85,6 +85,7 @@ generic_model_args:
init_depth_noise_std: 0.0005 init_depth_noise_std: 0.0005
hidden_size: 16 hidden_size: 16
n_feature_channels: 256 n_feature_channels: 256
bg_color: null
verbose: false verbose: false
renderer_MultiPassEmissionAbsorptionRenderer_args: renderer_MultiPassEmissionAbsorptionRenderer_args:
raymarcher_class_type: EmissionAbsorptionRaymarcher raymarcher_class_type: EmissionAbsorptionRaymarcher

View File

@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
class LSTMRenderer(BaseRenderer, torch.nn.Module): class LSTMRenderer(BaseRenderer, torch.nn.Module):
""" """
Implements the learnable LSTM raymarching function from SRN [1]. Implements the learnable LSTM raymarching function from SRN [1].
This requires there to be one implicit function, and it is expected to be
like SRNImplicitFunction or SRNHyperNetImplicitFunction.
Settings: Settings:
num_raymarch_steps: The number of LSTM raymarching steps. num_raymarch_steps: The number of LSTM raymarching steps.
@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
hidden_size: The dimensionality of the LSTM's hidden state. hidden_size: The dimensionality of the LSTM's hidden state.
n_feature_channels: The number of feature channels returned by the n_feature_channels: The number of feature channels returned by the
implicit_function evaluated at each raymarching step. implicit_function evaluated at each raymarching step.
bg_color: If supplied, used as the background color. Otherwise the pixel
generator is used everywhere. This has to have length either 1
(for a constant value for all output channels) or equal to the number
of output channels (which is `out_features` on the pixel generator,
typically 3.)
verbose: If `True`, logs raymarching debug info. verbose: If `True`, logs raymarching debug info.
References: References:
@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
init_depth_noise_std: float = 5e-4 init_depth_noise_std: float = 5e-4
hidden_size: int = 16 hidden_size: int = 16
n_feature_channels: int = 256 n_feature_channels: int = 256
bg_color: Optional[List[float]] = None
verbose: bool = False verbose: bool = False
def __post_init__(self): def __post_init__(self):
@ -147,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
dim=-1, keepdim=True dim=-1, keepdim=True
) )
if self.bg_color is not None:
background = features.new_tensor(self.bg_color)
features = torch.lerp(background, features, mask)
return RendererOutput( return RendererOutput(
features=features[..., 0, :], features=features[..., 0, :],
depths=depth, depths=depth,

View File

@ -72,6 +72,7 @@ renderer_LSTMRenderer_args:
init_depth_noise_std: 0.0005 init_depth_noise_std: 0.0005
hidden_size: 16 hidden_size: 16
n_feature_channels: 256 n_feature_channels: 256
bg_color: null
verbose: false verbose: false
image_feature_extractor_ResNetFeatureExtractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34 name: resnet34

View File

@ -7,14 +7,14 @@
import unittest import unittest
import torch import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction, SRNHyperNetImplicitFunction,
SRNImplicitFunction, SRNImplicitFunction,
SRNPixelGenerator, SRNPixelGenerator,
) )
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import get_default_args from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import RayBundle from pytorch3d.renderer import PerspectiveCameras, RayBundle
from tests.common_testing import TestCaseMixin from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3 _BATCH_SIZE: int = 3
@ -69,40 +69,23 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
) )
self.assertIsNone(rays_colors) self.assertIsNone(rays_colors)
def test_srn_hypernet_implicit_function_optim(self): def test_lstm(self):
# Test optimization loop, requiring that the cache is properly args = get_default_args(GenericModel)
# cleared in new_args_bound args.render_image_height = 80
latent_dim_hypernet = 39 args.render_image_width = 80
hyper_args = {"latent_dim_hypernet": latent_dim_hypernet} args.implicit_function_class_type = "SRNImplicitFunction"
device = torch.device("cuda:0") args.renderer_class_type = "LSTMRenderer"
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) args.raysampler_class_type = "NearFarRaySampler"
bundle = self._get_bundle(device=device) args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) gm = GenericModel(**args)
implicit_function.to(device) camera = PerspectiveCameras()
implicit_function2.to(device) gm.forward(
camera=camera,
wrapper = ImplicitFunctionWrapper(implicit_function) image_rgb=None,
optimizer = torch.optim.Adam(implicit_function.parameters()) fg_probability=None,
for _step in range(3): sequence_name="",
optimizer.zero_grad() mask_crop=None,
wrapper.bind_args(global_code=global_code) depth_map=None,
rays_densities, _rays_colors = wrapper(bundle) )
wrapper.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer.step()
wrapper2 = ImplicitFunctionWrapper(implicit_function)
optimizer2 = torch.optim.Adam(implicit_function2.parameters())
implicit_function2.load_state_dict(implicit_function.state_dict())
optimizer2.load_state_dict(optimizer.state_dict())
for _step in range(3):
optimizer2.zero_grad()
wrapper2.bind_args(global_code=global_code)
rays_densities, _rays_colors = wrapper2(bundle)
wrapper2.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer2.step()