mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Summary: Allow specifying a color for non-opaque pixels in LSTMRenderer. Reviewed By: davnov134. Differential Revision: D37172537. fbshipit-source-id: 6039726678bb7947f7d8cd04035b5023b2d5398c
92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
|
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
|
|
SRNHyperNetImplicitFunction,
|
|
SRNImplicitFunction,
|
|
SRNPixelGenerator,
|
|
)
|
|
from pytorch3d.implicitron.tools.config import get_default_args
|
|
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
|
from tests.common_testing import TestCaseMixin
|
|
|
|
_BATCH_SIZE: int = 3
|
|
_N_RAYS: int = 100
|
|
_N_POINTS_ON_RAY: int = 10
|
|
|
|
|
|
class TestSRN(TestCaseMixin, unittest.TestCase):
    """Smoke tests for the SRN implicit functions, pixel generator and LSTM renderer."""

    def setUp(self) -> None:
        torch.manual_seed(42)
        # The return values are discarded, so these calls are made for their
        # side effect only. Presumably they prepare the config dataclasses so the
        # classes can be constructed directly in the tests below.
        # NOTE(review): confirm against pytorch3d.implicitron.tools.config.
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        """SRNPixelGenerator can be constructed with its default arguments."""
        SRNPixelGenerator()

    def _get_bundle(self, *, device: torch.device) -> RayBundle:
        """Return a random RayBundle whose tensors live on `device`.

        origins and directions have shape (_BATCH_SIZE, _N_RAYS, 3), lengths has
        shape (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY), and xys is None.
        """
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        return RayBundle(
            lengths=lengths, origins=origins, directions=directions, xys=None
        )

    def _check_implicit_function_output(
        self, rays_densities, rays_colors, out_features: int
    ) -> None:
        """Check the (rays_densities, rays_colors) pair returned by an SRN function.

        Args:
            rays_densities: first output; must have shape
                (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features).
            rays_colors: second output; must be None.
            out_features: expected size of the last dimension of rays_densities.
        """
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_implicit_function(self):
        """SRNImplicitFunction on CPU returns one out_features vector per point."""
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(bundle)
        self._check_implicit_function_output(
            rays_densities,
            rays_colors,
            implicit_function.raymarch_function.out_features,
        )

    # The test places everything on "cuda:0". Skip it cleanly on CPU-only
    # machines instead of letting it error out.
    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_srn_hypernet_implicit_function(self):
        """SRNHyperNetImplicitFunction, conditioned on a global code, runs on CUDA."""
        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
        latent_dim_hypernet = 39
        hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
        device = torch.device("cuda:0")
        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
        implicit_function.to(device)
        # One latent code of size latent_dim_hypernet per batch element.
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
        self._check_implicit_function_output(
            rays_densities,
            rays_colors,
            implicit_function.hypernet.out_features,
        )

    def test_lstm(self):
        """A GenericModel using SRN + LSTMRenderer with a custom bg_color runs forward."""
        args = get_default_args(GenericModel)
        args.render_image_height = 80
        args.render_image_width = 80
        args.implicit_function_class_type = "SRNImplicitFunction"
        args.renderer_class_type = "LSTMRenderer"
        args.raysampler_class_type = "NearFarRaySampler"
        # A single sample per ray in both training and evaluation.
        # NOTE(review): presumably the configuration LSTMRenderer expects — confirm.
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
        # Non-default color for non-opaque pixels.
        args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
        gm = GenericModel(**args)
        camera = PerspectiveCameras()
        gm.forward(
            camera=camera,
            image_rgb=None,
            fg_probability=None,
            sequence_name="",
            mask_crop=None,
            depth_map=None,
        )
|