chunked_inputs

Summary: Make method for SDF's use of object mask more general, so that a renderer can be given per-pixel values.

Reviewed By: shapovalov

Differential Revision: D35247412

fbshipit-source-id: 6aeccb1d0b5f1265a3f692a1453407a07e51a33c
This commit is contained in:
Jeremy Reizenstein 2022-04-26 08:01:45 -07:00 committed by Facebook GitHub Bot
parent 41c594ca37
commit 2edb93d184
2 changed files with 56 additions and 17 deletions

View File

@ -396,12 +396,12 @@ class GenericModel(Configurable, torch.nn.Module):
for func in self._implicit_functions:
func.bind_args(**custom_args)
object_mask: Optional[torch.Tensor] = None
chunked_renderer_inputs = {}
if fg_probability is not None:
sampled_fb_prob = rend_utils.ndc_grid_sample(
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
)
object_mask = sampled_fb_prob > 0.5
chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5
# (5)-(6) Implicit function evaluation and Rendering
rendered = self._render(
@ -409,7 +409,7 @@ class GenericModel(Configurable, torch.nn.Module):
sampling_mode=sampling_mode,
evaluation_mode=evaluation_mode,
implicit_functions=self._implicit_functions,
object_mask=object_mask,
chunked_inputs=chunked_renderer_inputs,
)
# Unbind the custom arguments to prevent pytorch from storing
@ -501,7 +501,6 @@ class GenericModel(Configurable, torch.nn.Module):
Helper function to visualize the predictions generated
in the forward pass.
Args:
viz: Visdom connection object
visdom_env_imgs: name of visdom environment for the images.
@ -521,7 +520,7 @@ class GenericModel(Configurable, torch.nn.Module):
self,
*,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
chunked_inputs: Dict[str, torch.Tensor],
sampling_mode: RenderSamplingMode,
**kwargs,
) -> RendererOutput:
@ -529,13 +528,16 @@ class GenericModel(Configurable, torch.nn.Module):
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
in the image. This is required for the SignedDistanceFunctionRenderer.
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
SignedDistanceFunctionRenderer requires "object_mask", shape
(B, 1, H, W), the silhouette of the object in the image. When
chunking, they are passed to the renderer as shape
`(B, _, chunksize)`.
sampling_mode: The sampling method to use. Must be a value from the
RenderSamplingMode Enum.
Returns:
An instance of RendererOutput
"""
if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
return _apply_chunked(
@ -543,7 +545,7 @@ class GenericModel(Configurable, torch.nn.Module):
_chunk_generator(
self.chunk_size_grid,
ray_bundle,
object_mask,
chunked_inputs,
self.tqdm_trigger_threshold,
**kwargs,
),
@ -553,7 +555,7 @@ class GenericModel(Configurable, torch.nn.Module):
# pyre-fixme[29]: `BaseRenderer` is not a function.
return self.renderer(
ray_bundle=ray_bundle,
object_mask=object_mask,
**chunked_inputs,
**kwargs,
)
@ -837,7 +839,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
def _chunk_generator(
chunk_size: int,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
chunked_inputs: Dict[str, torch.Tensor],
tqdm_trigger_threshold: int,
*args,
**kwargs,
@ -880,8 +882,6 @@ def _chunk_generator(
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
)
extra_args = kwargs.copy()
if object_mask is not None:
extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
:, start_idx:end_idx
]
for k, v in chunked_inputs.items():
extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
yield [ray_bundle_chunk, *args], extra_args

View File

@ -9,13 +9,13 @@ import unittest
import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
class TestGenericModel(unittest.TestCase):
def test_gm(self):
# Simple test of a forward pass of the default GenericModel.
# Simple test of a forward and backward pass of the default GenericModel.
device = torch.device("cuda:1")
expand_args_fields(GenericModel)
model = GenericModel()
@ -51,6 +51,7 @@ class TestGenericModel(unittest.TestCase):
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)
train_preds["objective"].backward()
model.eval()
with torch.no_grad():
@ -65,3 +66,41 @@ class TestGenericModel(unittest.TestCase):
eval_preds["images_render"].shape,
(1, 3, model.render_image_height, model.render_image_width),
)
def test_idr(self):
# Forward pass of GenericModel with IDR.
device = torch.device("cuda:1")
args = get_default_args(GenericModel)
args.renderer_class_type = "SignedDistanceFunctionRenderer"
args.implicit_function_class_type = "IdrFeatureField"
args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
model = GenericModel(**args)
model.to(device)
n_train_cameras = 2
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
cameras = PerspectiveCameras(R=R, T=T, device=device)
defaulted_args = {
"depth_map": None,
"mask_crop": None,
"sequence_name": None,
}
target_image_rgb = torch.rand(
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
device=device,
)
fg_probability = torch.rand(
(n_train_cameras, 1, model.render_image_height, model.render_image_width),
device=device,
)
train_preds = model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
image_rgb=target_image_rgb,
fg_probability=fg_probability,
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)