From cba26506b6fe8a98695f50673cb20d9597d87551 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 20 Jun 2022 13:46:35 -0700 Subject: [PATCH] bg_color for lstm renderer Summary: Allow specifying a color for non-opaque pixels in LSTMRenderer. Reviewed By: davnov134 Differential Revision: D37172537 fbshipit-source-id: 6039726678bb7947f7d8cd04035b5023b2d5398c --- .../implicitron_trainer/tests/experiment.yaml | 1 + .../models/renderer/lstm_renderer.py | 12 ++++ tests/implicitron/data/overrides.yaml | 1 + tests/implicitron/test_srn.py | 61 +++++++------------ 4 files changed, 36 insertions(+), 39 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index d560c8d7..eda38f3a 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -85,6 +85,7 @@ generic_model_args: init_depth_noise_std: 0.0005 hidden_size: 16 n_feature_channels: 256 + bg_color: null verbose: false renderer_MultiPassEmissionAbsorptionRenderer_args: raymarcher_class_type: EmissionAbsorptionRaymarcher diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index 5121ea16..64cd89cf 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -21,6 +21,8 @@ logger = logging.getLogger(__name__) class LSTMRenderer(BaseRenderer, torch.nn.Module): """ Implements the learnable LSTM raymarching function from SRN [1]. + This requires there to be one implicit function, and it is expected to be + like SRNImplicitFunction or SRNHyperNetImplicitFunction. Settings: num_raymarch_steps: The number of LSTM raymarching steps. @@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): hidden_size: The dimensionality of the LSTM's hidden state. n_feature_channels: The number of feature channels returned by the implicit_function evaluated at each raymarching step. + bg_color: If supplied, used as the background color. Otherwise the pixel + generator is used everywhere. This has to have length either 1 + (for a constant value for all output channels) or equal to the number + of output channels (which is `out_features` on the pixel generator, + typically 3.) verbose: If `True`, logs raymarching debug info. References: @@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): init_depth_noise_std: float = 5e-4 hidden_size: int = 16 n_feature_channels: int = 256 + bg_color: Optional[List[float]] = None verbose: bool = False def __post_init__(self): @@ -147,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): dim=-1, keepdim=True ) + if self.bg_color is not None: + background = features.new_tensor(self.bg_color) + features = torch.lerp(background, features, mask) + return RendererOutput( features=features[..., 0, :], depths=depth, diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index ea6cbb2f..b1e0489f 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -72,6 +72,7 @@ renderer_LSTMRenderer_args: init_depth_noise_std: 0.0005 hidden_size: 16 n_feature_channels: 256 + bg_color: null verbose: false image_feature_extractor_ResNetFeatureExtractor_args: name: resnet34 diff --git a/tests/implicitron/test_srn.py b/tests/implicitron/test_srn.py index cff633fd..367ec41e 100644 --- a/tests/implicitron/test_srn.py +++ b/tests/implicitron/test_srn.py @@ -7,14 +7,14 @@ import unittest import torch +from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( SRNHyperNetImplicitFunction, SRNImplicitFunction, SRNPixelGenerator, ) -from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper from pytorch3d.implicitron.tools.config import get_default_args -from pytorch3d.renderer import RayBundle +from pytorch3d.renderer import PerspectiveCameras, RayBundle from tests.common_testing import TestCaseMixin _BATCH_SIZE: int = 3 @@ -69,40 +69,23 @@ class TestSRN(TestCaseMixin, unittest.TestCase): ) self.assertIsNone(rays_colors) - def test_srn_hypernet_implicit_function_optim(self): - # Test optimization loop, requiring that the cache is properly - # cleared in new_args_bound - latent_dim_hypernet = 39 - hyper_args = {"latent_dim_hypernet": latent_dim_hypernet} - device = torch.device("cuda:0") - global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) - bundle = self._get_bundle(device=device) - - implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) - implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) - implicit_function.to(device) - implicit_function2.to(device) - - wrapper = ImplicitFunctionWrapper(implicit_function) - optimizer = torch.optim.Adam(implicit_function.parameters()) - for _step in range(3): - optimizer.zero_grad() - wrapper.bind_args(global_code=global_code) - rays_densities, _rays_colors = wrapper(bundle) - wrapper.unbind_args() - loss = rays_densities.sum() - loss.backward() - optimizer.step() - - wrapper2 = ImplicitFunctionWrapper(implicit_function) - optimizer2 = torch.optim.Adam(implicit_function2.parameters()) - implicit_function2.load_state_dict(implicit_function.state_dict()) - optimizer2.load_state_dict(optimizer.state_dict()) - for _step in range(3): - optimizer2.zero_grad() - wrapper2.bind_args(global_code=global_code) - rays_densities, _rays_colors = wrapper2(bundle) - wrapper2.unbind_args() - loss = rays_densities.sum() - loss.backward() - optimizer2.step() + def test_lstm(self): + args = get_default_args(GenericModel) + args.render_image_height = 80 + args.render_image_width = 80 + args.implicit_function_class_type = "SRNImplicitFunction" + args.renderer_class_type = "LSTMRenderer" + args.raysampler_class_type = "NearFarRaySampler" + args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1 + args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1 + args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2] + gm = GenericModel(**args) + camera = PerspectiveCameras() + gm.forward( + camera=camera, + image_rgb=None, + fg_probability=None, + sequence_name="", + mask_crop=None, + depth_map=None, + )