srn/idr followups

Summary: small followup to D37172537 (cba26506b6) and D37209012 (81d63c6382): changing default #harmonics and improving a test

Reviewed By: shapovalov

Differential Revision: D37412357

fbshipit-source-id: 1af1005a129425fd24fa6dd213d69c71632099a0
This commit is contained in:
Jeremy Reizenstein 2022-06-24 04:07:15 -07:00 committed by Facebook GitHub Bot
parent 3e4fb0b9d9
commit 5c1ca757bb
3 changed files with 24 additions and 4 deletions

View File

@ -201,7 +201,7 @@ generic_model_args:
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: -1
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeRFormerImplicitFunction_args:

View File

@ -58,7 +58,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
bias: float = 1.0
skip_in: Sequence[int] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = -1
n_harmonic_functions_xyz: int = 0
pooled_feature_dim: int = 0
encoding_dim: int = 0

View File

@ -69,6 +69,7 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
)
self.assertIsNone(rays_colors)
@torch.no_grad()
def test_lstm(self):
args = get_default_args(GenericModel)
args.render_image_height = 80
@ -80,12 +81,31 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
gm = GenericModel(**args)
camera = PerspectiveCameras()
gm.forward(
image = gm.forward(
camera=camera,
image_rgb=None,
fg_probability=None,
sequence_name="",
mask_crop=None,
depth_map=None,
)
)["images_render"]
self.assertEqual(image.shape, (1, 3, 80, 80))
self.assertGreater(image.max(), 0.8)
# Force everything to be background
pixel_generator = gm._implicit_functions[0]._fn.pixel_generator
pixel_generator._density_layer.weight.zero_()
pixel_generator._density_layer.bias.fill_(-1.0e6)
image = gm.forward(
camera=camera,
image_rgb=None,
fg_probability=None,
sequence_name="",
mask_crop=None,
depth_map=None,
)["images_render"]
self.assertConstant(image[:, :2], 0.4)
self.assertConstant(image[:, 2], 0.2)