Darijan Gudelj 72c3a0ebe5 raybundle input to ImplicitFunctions -> api unification
Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera.

Reviewed By: bottler

Differential Revision: D39173751

fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
2022-09-05 06:26:06 -07:00

115 lines
4.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import PerspectiveCameras, RayBundle
from tests.common_testing import TestCaseMixin
# Sizes of the synthetic ray bundle built by TestSRN._get_bundle.
# Leading (batch) dimension of the origins, directions and lengths tensors.
_BATCH_SIZE: int = 3
# Second dimension of those tensors, i.e. rays per batch element.
_N_RAYS: int = 100
# Last dimension of the lengths tensor, i.e. sampled points per ray.
_N_POINTS_ON_RAY: int = 10
class TestSRN(TestCaseMixin, unittest.TestCase):
    """Tests for the SRN implicit functions and for rendering with them."""

    def setUp(self) -> None:
        """Seed the RNG and initialise the SRN implicit function classes."""
        torch.manual_seed(42)
        # Called only for their side effects; the returned defaults are unused.
        # NOTE(review): presumably this prepares the configurable classes so
        # that they can be instantiated directly in the tests -- confirm.
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        """Smoke test: SRNPixelGenerator can be built with default arguments."""
        SRNPixelGenerator()

    def _get_bundle(self, *, device) -> RayBundle:
        """Build a random RayBundle on `device`.

        Args:
            device: the torch device on which the tensors are created.

        Returns:
            A RayBundle whose origins and directions have shape
            (_BATCH_SIZE, _N_RAYS, 3), whose lengths have shape
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY) and whose xys is None.
        """
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        return RayBundle(
            lengths=lengths, origins=origins, directions=directions, xys=None
        )

    def test_srn_implicit_function(self):
        """Check the output shape of SRNImplicitFunction on the CPU."""
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
        out_features = implicit_function.raymarch_function.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function(self):
        """Check the output shape of SRNHyperNetImplicitFunction on the GPU."""
        # This test exercises the CUDA path; report it as skipped rather than
        # as an error on hosts without a usable GPU.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
        latent_dim_hypernet = 39
        device = torch.device("cuda:0")
        implicit_function = SRNHyperNetImplicitFunction(
            latent_dim_hypernet=latent_dim_hypernet
        )
        implicit_function.to(device)
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(
            ray_bundle=bundle, global_code=global_code
        )
        out_features = implicit_function.hypernet.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    @torch.no_grad()
    def test_lstm(self):
        """Render with GenericModel using SRNImplicitFunction and LSTMRenderer.

        Checks the rendered image shape, then forces the density layer to a
        very negative output and checks that the render equals the configured
        background colour.
        """
        args = get_default_args(GenericModel)
        args.render_image_height = 80
        args.render_image_width = 80
        args.implicit_function_class_type = "SRNImplicitFunction"
        args.renderer_class_type = "LSTMRenderer"
        args.raysampler_class_type = "NearFarRaySampler"
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
        args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
        gm = GenericModel(**args)

        camera = PerspectiveCameras()

        def render():
            # Both renders below use exactly the same inputs; only the
            # model weights change in between.
            return gm.forward(
                camera=camera,
                image_rgb=None,
                fg_probability=None,
                sequence_name="",
                mask_crop=None,
                depth_map=None,
            )["images_render"]

        image = render()
        self.assertEqual(image.shape, (1, 3, 80, 80))
        # The largest background channel is 0.4, so a value above 0.8 means
        # that the render is not background only.
        self.assertGreater(image.max(), 0.8)

        # Force everything to be background
        pixel_generator = gm._implicit_functions[0]._fn.pixel_generator
        pixel_generator._density_layer.weight.zero_()
        pixel_generator._density_layer.bias.fill_(-1.0e6)

        image = render()
        # Channels must now match bg_color = [0.4, 0.4, 0.2].
        self.assertConstant(image[:, :2], 0.4)
        self.assertConstant(image[:, 2], 0.2)