mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
bg_color for lstm renderer
Summary: Allow specifying a color for non-opaque pixels in LSTMRenderer.
Reviewed By: davnov134
Differential Revision: D37172537
fbshipit-source-id: 6039726678bb7947f7d8cd04035b5023b2d5398c
This commit is contained in:
parent
65f667fd2e
commit
cba26506b6
@ -85,6 +85,7 @@ generic_model_args:
|
|||||||
init_depth_noise_std: 0.0005
|
init_depth_noise_std: 0.0005
|
||||||
hidden_size: 16
|
hidden_size: 16
|
||||||
n_feature_channels: 256
|
n_feature_channels: 256
|
||||||
|
bg_color: null
|
||||||
verbose: false
|
verbose: false
|
||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||||
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
||||||
|
@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Implements the learnable LSTM raymarching function from SRN [1].
|
Implements the learnable LSTM raymarching function from SRN [1].
|
||||||
|
This requires there to be one implicit function, and it is expected to be
|
||||||
|
like SRNImplicitFunction or SRNHyperNetImplicitFunction.
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
num_raymarch_steps: The number of LSTM raymarching steps.
|
num_raymarch_steps: The number of LSTM raymarching steps.
|
||||||
@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
hidden_size: The dimensionality of the LSTM's hidden state.
|
hidden_size: The dimensionality of the LSTM's hidden state.
|
||||||
n_feature_channels: The number of feature channels returned by the
|
n_feature_channels: The number of feature channels returned by the
|
||||||
implicit_function evaluated at each raymarching step.
|
implicit_function evaluated at each raymarching step.
|
||||||
|
bg_color: If supplied, used as the background color. Otherwise the pixel
|
||||||
|
generator is used everywhere. This has to have length either 1
|
||||||
|
(for a constant value for all output channels) or equal to the number
|
||||||
|
of output channels (which is `out_features` on the pixel generator,
|
||||||
|
typically 3.)
|
||||||
verbose: If `True`, logs raymarching debug info.
|
verbose: If `True`, logs raymarching debug info.
|
||||||
|
|
||||||
References:
|
References:
|
||||||
@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
init_depth_noise_std: float = 5e-4
|
init_depth_noise_std: float = 5e-4
|
||||||
hidden_size: int = 16
|
hidden_size: int = 16
|
||||||
n_feature_channels: int = 256
|
n_feature_channels: int = 256
|
||||||
|
bg_color: Optional[List[float]] = None
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -147,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
dim=-1, keepdim=True
|
dim=-1, keepdim=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.bg_color is not None:
|
||||||
|
background = features.new_tensor(self.bg_color)
|
||||||
|
features = torch.lerp(background, features, mask)
|
||||||
|
|
||||||
return RendererOutput(
|
return RendererOutput(
|
||||||
features=features[..., 0, :],
|
features=features[..., 0, :],
|
||||||
depths=depth,
|
depths=depth,
|
||||||
|
@ -72,6 +72,7 @@ renderer_LSTMRenderer_args:
|
|||||||
init_depth_noise_std: 0.0005
|
init_depth_noise_std: 0.0005
|
||||||
hidden_size: 16
|
hidden_size: 16
|
||||||
n_feature_channels: 256
|
n_feature_channels: 256
|
||||||
|
bg_color: null
|
||||||
verbose: false
|
verbose: false
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
name: resnet34
|
name: resnet34
|
||||||
|
@ -7,14 +7,14 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
|
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
|
||||||
SRNHyperNetImplicitFunction,
|
SRNHyperNetImplicitFunction,
|
||||||
SRNImplicitFunction,
|
SRNImplicitFunction,
|
||||||
SRNPixelGenerator,
|
SRNPixelGenerator,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
|
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
from pytorch3d.renderer import RayBundle
|
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
_BATCH_SIZE: int = 3
|
_BATCH_SIZE: int = 3
|
||||||
@ -69,40 +69,23 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsNone(rays_colors)
|
self.assertIsNone(rays_colors)
|
||||||
|
|
||||||
def test_srn_hypernet_implicit_function_optim(self):
|
def test_lstm(self):
|
||||||
# Test optimization loop, requiring that the cache is properly
|
args = get_default_args(GenericModel)
|
||||||
# cleared in new_args_bound
|
args.render_image_height = 80
|
||||||
latent_dim_hypernet = 39
|
args.render_image_width = 80
|
||||||
hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
|
args.implicit_function_class_type = "SRNImplicitFunction"
|
||||||
device = torch.device("cuda:0")
|
args.renderer_class_type = "LSTMRenderer"
|
||||||
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
args.raysampler_class_type = "NearFarRaySampler"
|
||||||
bundle = self._get_bundle(device=device)
|
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
|
||||||
|
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
|
||||||
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
|
args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
|
||||||
implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
|
gm = GenericModel(**args)
|
||||||
implicit_function.to(device)
|
camera = PerspectiveCameras()
|
||||||
implicit_function2.to(device)
|
gm.forward(
|
||||||
|
camera=camera,
|
||||||
wrapper = ImplicitFunctionWrapper(implicit_function)
|
image_rgb=None,
|
||||||
optimizer = torch.optim.Adam(implicit_function.parameters())
|
fg_probability=None,
|
||||||
for _step in range(3):
|
sequence_name="",
|
||||||
optimizer.zero_grad()
|
mask_crop=None,
|
||||||
wrapper.bind_args(global_code=global_code)
|
depth_map=None,
|
||||||
rays_densities, _rays_colors = wrapper(bundle)
|
)
|
||||||
wrapper.unbind_args()
|
|
||||||
loss = rays_densities.sum()
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
wrapper2 = ImplicitFunctionWrapper(implicit_function)
|
|
||||||
optimizer2 = torch.optim.Adam(implicit_function2.parameters())
|
|
||||||
implicit_function2.load_state_dict(implicit_function.state_dict())
|
|
||||||
optimizer2.load_state_dict(optimizer.state_dict())
|
|
||||||
for _step in range(3):
|
|
||||||
optimizer2.zero_grad()
|
|
||||||
wrapper2.bind_args(global_code=global_code)
|
|
||||||
rays_densities, _rays_colors = wrapper2(bundle)
|
|
||||||
wrapper2.unbind_args()
|
|
||||||
loss = rays_densities.sum()
|
|
||||||
loss.backward()
|
|
||||||
optimizer2.step()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user