From 2edb93d184af275b69b8f7d824ded66e1a204bd1 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 26 Apr 2022 08:01:45 -0700 Subject: [PATCH] chunked_inputs Summary: Make method for SDF's use of object mask more general, so that a renderer can be given per-pixel values. Reviewed By: shapovalov Differential Revision: D35247412 fbshipit-source-id: 6aeccb1d0b5f1265a3f692a1453407a07e51a33c --- pytorch3d/implicitron/models/base.py | 30 +++++++++--------- tests/implicitron/test_forward_pass.py | 43 ++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/pytorch3d/implicitron/models/base.py b/pytorch3d/implicitron/models/base.py index fdb087d4..a2b303d3 100644 --- a/pytorch3d/implicitron/models/base.py +++ b/pytorch3d/implicitron/models/base.py @@ -396,12 +396,12 @@ class GenericModel(Configurable, torch.nn.Module): for func in self._implicit_functions: func.bind_args(**custom_args) - object_mask: Optional[torch.Tensor] = None + chunked_renderer_inputs = {} if fg_probability is not None: sampled_fb_prob = rend_utils.ndc_grid_sample( fg_probability[:n_targets], ray_bundle.xys, mode="nearest" ) - object_mask = sampled_fb_prob > 0.5 + chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5 # (5)-(6) Implicit function evaluation and Rendering rendered = self._render( @@ -409,7 +409,7 @@ class GenericModel(Configurable, torch.nn.Module): sampling_mode=sampling_mode, evaluation_mode=evaluation_mode, implicit_functions=self._implicit_functions, - object_mask=object_mask, + chunked_inputs=chunked_renderer_inputs, ) # Unbind the custom arguments to prevent pytorch from storing @@ -501,7 +501,6 @@ class GenericModel(Configurable, torch.nn.Module): Helper function to visualize the predictions generated in the forward pass. - Args: viz: Visdom connection object visdom_env_imgs: name of visdom environment for the images. 
@@ -521,7 +520,7 @@ class GenericModel(Configurable, torch.nn.Module): self, *, ray_bundle: RayBundle, - object_mask: Optional[torch.Tensor], + chunked_inputs: Dict[str, torch.Tensor], sampling_mode: RenderSamplingMode, **kwargs, ) -> RendererOutput: @@ -529,13 +528,16 @@ class GenericModel(Configurable, torch.nn.Module): Args: ray_bundle: A `RayBundle` object containing the parametrizations of the sampled rendering rays. - object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object - in the image. This is required for the SignedDistanceFunctionRenderer. + chunked_inputs: A collection of tensors of shape `(B, _, H, W)`. E.g. + SignedDistanceFunctionRenderer requires "object_mask", shape + (B, 1, H, W), the silhouette of the object in the image. When + chunking, they are passed to the renderer as shape + `(B, _, chunksize)`. sampling_mode: The sampling method to use. Must be a value from the RenderSamplingMode Enum. + Returns: An instance of RendererOutput - """ if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0: return _apply_chunked( @@ -543,7 +545,7 @@ class GenericModel(Configurable, torch.nn.Module): _chunk_generator( self.chunk_size_grid, ray_bundle, - object_mask, + chunked_inputs, self.tqdm_trigger_threshold, **kwargs, ), @@ -553,7 +555,7 @@ class GenericModel(Configurable, torch.nn.Module): # pyre-fixme[29]: `BaseRenderer` is not a function. 
return self.renderer( ray_bundle=ray_bundle, - object_mask=object_mask, + **chunked_inputs, **kwargs, ) @@ -837,7 +839,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor: def _chunk_generator( chunk_size: int, ray_bundle: RayBundle, - object_mask: Optional[torch.Tensor], + chunked_inputs: Dict[str, torch.Tensor], tqdm_trigger_threshold: int, *args, **kwargs, @@ -880,8 +882,6 @@ def _chunk_generator( xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], ) extra_args = kwargs.copy() - if object_mask is not None: - extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[ - :, start_idx:end_idx - ] + for k, v in chunked_inputs.items(): + extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx] yield [ray_bundle_chunk, *args], extra_args diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index c7406a6b..e1909248 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -9,13 +9,13 @@ import unittest import torch from pytorch3d.implicitron.models.base import GenericModel from pytorch3d.implicitron.models.renderer.base import EvaluationMode -from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras class TestGenericModel(unittest.TestCase): def test_gm(self): - # Simple test of a forward pass of the default GenericModel. + # Simple test of a forward and backward pass of the default GenericModel. 
device = torch.device("cuda:1") expand_args_fields(GenericModel) model = GenericModel() @@ -51,6 +51,7 @@ class TestGenericModel(unittest.TestCase): **defaulted_args, ) self.assertGreater(train_preds["objective"].item(), 0) + train_preds["objective"].backward() model.eval() with torch.no_grad(): @@ -65,3 +66,41 @@ class TestGenericModel(unittest.TestCase): eval_preds["images_render"].shape, (1, 3, model.render_image_height, model.render_image_width), ) + + def test_idr(self): + # Forward pass of GenericModel with IDR. + device = torch.device("cuda:1") + args = get_default_args(GenericModel) + args.renderer_class_type = "SignedDistanceFunctionRenderer" + args.implicit_function_class_type = "IdrFeatureField" + args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6 + + model = GenericModel(**args) + model.to(device) + + n_train_cameras = 2 + R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) + cameras = PerspectiveCameras(R=R, T=T, device=device) + + defaulted_args = { + "depth_map": None, + "mask_crop": None, + "sequence_name": None, + } + + target_image_rgb = torch.rand( + (n_train_cameras, 3, model.render_image_height, model.render_image_width), + device=device, + ) + fg_probability = torch.rand( + (n_train_cameras, 1, model.render_image_height, model.render_image_width), + device=device, + ) + train_preds = model( + camera=cameras, + evaluation_mode=EvaluationMode.TRAINING, + image_rgb=target_image_rgb, + fg_probability=fg_probability, + **defaulted_args, + ) + self.assertGreater(train_preds["objective"].item(), 0)