Darijan Gudelj ad8907d373 ImplicitronRayBundle
Summary: New ImplicitronRayBundle with added camera IDs and camera counts. It was added to allow a single ray bundle type inside Implicitron and to make future extension easier. RayBundle is a named tuple and RayBundleHeterogeneous is a dataclass, so RayBundleHeterogeneous cannot inherit from RayBundle. Without ImplicitronRayBundle, every function that currently uses RayBundle would have to accept Union[RayBundle, RayBundleHeterogeneous], which is confusing and unnecessarily complicated.

Reviewed By: bottler, kjchalup

Differential Revision: D39262999

fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
2022-10-03 08:36:47 -07:00

122 lines
4.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import PerspectiveCameras
from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3
_N_RAYS: int = 100
_N_POINTS_ON_RAY: int = 10
class TestSRN(TestCaseMixin, unittest.TestCase):
    """Shape and rendering tests for the SRN implicit functions and renderer."""

    def setUp(self) -> None:
        torch.manual_seed(42)
        # NOTE(review): get_default_args is presumably called here for its side
        # effect of processing these classes' config fields so that they can be
        # instantiated directly in the tests below — confirm against
        # pytorch3d.implicitron.tools.config.
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        """SRNPixelGenerator can be constructed with its default arguments."""
        SRNPixelGenerator()

    def _get_bundle(self, *, device) -> ImplicitronRayBundle:
        """Return a random ImplicitronRayBundle on `device`.

        origins/directions have shape (_BATCH_SIZE, _N_RAYS, 3) and lengths
        has shape (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY). The optional
        fields (xys, camera_ids, camera_counts) are left as None.
        """
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        return ImplicitronRayBundle(
            lengths=lengths,
            origins=origins,
            directions=directions,
            xys=None,
            camera_ids=None,
            camera_counts=None,
        )

    def test_srn_implicit_function(self):
        """SRNImplicitFunction yields per-point features and no colors."""
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
        out_features = implicit_function.raymarch_function.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function(self):
        """SRNHyperNetImplicitFunction conditioned on a global code yields
        per-point features and no colors."""
        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
        latent_dim_hypernet = 39
        # Prefer the GPU, but fall back to the CPU so the test does not crash
        # on machines without CUDA (previously "cuda:0" was hard-coded).
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        implicit_function = SRNHyperNetImplicitFunction(
            latent_dim_hypernet=latent_dim_hypernet
        )
        implicit_function.to(device)
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(
            ray_bundle=bundle, global_code=global_code
        )
        out_features = implicit_function.hypernet.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    @torch.no_grad()
    def test_lstm(self):
        """GenericModel with SRN + LSTMRenderer renders an 80x80 image; when
        the density is forced to ~0, every pixel takes the background color."""
        args = get_default_args(GenericModel)
        args.render_image_height = 80
        args.render_image_width = 80
        args.implicit_function_class_type = "SRNImplicitFunction"
        args.renderer_class_type = "LSTMRenderer"
        args.raysampler_class_type = "NearFarRaySampler"
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
        args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
        gm = GenericModel(**args)
        camera = PerspectiveCameras()

        def render():
            # Render with only a camera; all optional inputs are left empty.
            return gm.forward(
                camera=camera,
                image_rgb=None,
                fg_probability=None,
                sequence_name="",
                mask_crop=None,
                depth_map=None,
            )["images_render"]

        image = render()
        self.assertEqual(image.shape, (1, 3, 80, 80))
        self.assertGreater(image.max(), 0.8)

        # Force everything to be background: zero weights and a huge negative
        # bias on the density layer (reached via private attributes).
        pixel_generator = gm._implicit_functions[0]._fn.pixel_generator
        pixel_generator._density_layer.weight.zero_()
        pixel_generator._density_layer.bias.fill_(-1.0e6)
        image = render()
        # Channels must now equal bg_color = [0.4, 0.4, 0.2].
        self.assertConstant(image[:, :2], 0.4)
        self.assertConstant(image[:, 2], 0.2)