chunked_inputs

Summary: Generalize the way the SDF renderer receives the object mask, so that a renderer can be given arbitrary per-pixel values.

Reviewed By: shapovalov

Differential Revision: D35247412

fbshipit-source-id: 6aeccb1d0b5f1265a3f692a1453407a07e51a33c

This commit is contained in:
parent 41c594ca37
commit 2edb93d184
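
In outline, the change replaces the single optional `object_mask` tensor threaded through `_render` and `_chunk_generator` with a dict of per-pixel tensors that is passed through uniformly. A rough sketch of the new convention (tensor sizes are made up):

import torch

# Old convention: one special-cased argument.
# renderer(ray_bundle=ray_bundle, object_mask=object_mask, **kwargs)

# New convention: a dict of per-pixel tensors, keyed by the keyword
# argument name the renderer expects, chunked alongside the rays
# whenever chunking is enabled.
chunked_renderer_inputs = {"object_mask": torch.rand(2, 1, 128, 128) > 0.5}
# renderer(ray_bundle=ray_bundle, **chunked_renderer_inputs, **kwargs)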
@@ -396,12 +396,12 @@ class GenericModel(Configurable, torch.nn.Module):
         for func in self._implicit_functions:
             func.bind_args(**custom_args)
 
-        object_mask: Optional[torch.Tensor] = None
+        chunked_renderer_inputs = {}
         if fg_probability is not None:
             sampled_fb_prob = rend_utils.ndc_grid_sample(
                 fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
             )
-            object_mask = sampled_fb_prob > 0.5
+            chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5
 
         # (5)-(6) Implicit function evaluation and Rendering
         rendered = self._render(
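
For context, `ndc_grid_sample` looks the probability map up at each ray's NDC xy location; thresholding then gives a per-ray boolean mask. A small sketch, assuming the helper accepts a `(B, N, 2)` NDC grid and forwards `mode` to `grid_sample` (shapes here are made up):

import torch
from pytorch3d.renderer import utils as rend_utils

batch = 2
fg_probability = torch.rand(batch, 1, 128, 128)  # per-pixel foreground probability
xys = torch.rand(batch, 1024, 2) * 2 - 1         # ray xy locations in NDC space

# Nearest-neighbour lookup of the map at each ray location, mirroring
# the call in the hunk above.
sampled_fg_prob = rend_utils.ndc_grid_sample(fg_probability, xys, mode="nearest")
object_mask = sampled_fg_prob > 0.5              # one boolean per ray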
@@ -409,7 +409,7 @@ class GenericModel(Configurable, torch.nn.Module):
             sampling_mode=sampling_mode,
             evaluation_mode=evaluation_mode,
             implicit_functions=self._implicit_functions,
-            object_mask=object_mask,
+            chunked_inputs=chunked_renderer_inputs,
         )
 
         # Unbind the custom arguments to prevent pytorch from storing
@@ -501,7 +501,6 @@ class GenericModel(Configurable, torch.nn.Module):
         Helper function to visualize the predictions generated
         in the forward pass.
-
 
         Args:
             viz: Visdom connection object
             visdom_env_imgs: name of visdom environment for the images.
@@ -521,7 +520,7 @@ class GenericModel(Configurable, torch.nn.Module):
         self,
         *,
         ray_bundle: RayBundle,
-        object_mask: Optional[torch.Tensor],
+        chunked_inputs: Dict[str, torch.Tensor],
         sampling_mode: RenderSamplingMode,
         **kwargs,
     ) -> RendererOutput:
@@ -529,13 +528,16 @@ class GenericModel(Configurable, torch.nn.Module):
         Args:
             ray_bundle: A `RayBundle` object containing the parametrizations of the
                 sampled rendering rays.
-            object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
-                in the image. This is required for the SignedDistanceFunctionRenderer.
+            chunked_inputs: A collection of tensors of shape `(B, _, H, W)`. E.g.
+                SignedDistanceFunctionRenderer requires "object_mask", shape
+                (B, 1, H, W), the silhouette of the object in the image. When
+                chunking, they are passed to the renderer as shape
+                `(B, _, chunksize)`.
             sampling_mode: The sampling method to use. Must be a value from the
                 RenderSamplingMode Enum.
 
         Returns:
             An instance of RendererOutput
         """
         if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
             return _apply_chunked(
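
Note the shape contract the new docstring states: each entry enters `_render` as a full per-pixel map of shape `(B, _, H, W)` but, when the full grid is chunked, reaches the renderer as `(B, _, chunksize)` slices. With illustrative numbers, a 128×128 render with `chunk_size_grid = 4096` yields 128·128 / 4096 = 4 chunks per batch element.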
@@ -543,7 +545,7 @@ class GenericModel(Configurable, torch.nn.Module):
                 _chunk_generator(
                     self.chunk_size_grid,
                     ray_bundle,
-                    object_mask,
+                    chunked_inputs,
                     self.tqdm_trigger_threshold,
                     **kwargs,
                 ),
@@ -553,7 +555,7 @@ class GenericModel(Configurable, torch.nn.Module):
         # pyre-fixme[29]: `BaseRenderer` is not a function.
         return self.renderer(
             ray_bundle=ray_bundle,
-            object_mask=object_mask,
+            **chunked_inputs,
             **kwargs,
         )
 
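
In the unchunked path above, the dict is splatted straight into the renderer call, so each key must match a keyword argument the concrete renderer accepts. A tiny sketch with a hypothetical stand-in renderer:

import torch

def renderer(*, ray_bundle, object_mask=None, **kwargs):
    # Hypothetical stand-in for the renderer call; only the signature
    # matters here: "object_mask" arrives as an ordinary kwarg.
    return object_mask

chunked_inputs = {"object_mask": torch.rand(2, 1, 8, 8) > 0.5}
out = renderer(ray_bundle=None, **chunked_inputs)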
@@ -837,7 +839,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
 def _chunk_generator(
     chunk_size: int,
     ray_bundle: RayBundle,
-    object_mask: Optional[torch.Tensor],
+    chunked_inputs: Dict[str, torch.Tensor],
     tqdm_trigger_threshold: int,
     *args,
     **kwargs,
@@ -880,8 +882,6 @@ def _chunk_generator(
             xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
         )
         extra_args = kwargs.copy()
-        if object_mask is not None:
-            extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
-                :, start_idx:end_idx
-            ]
+        for k, v in chunked_inputs.items():
+            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
         yield [ray_bundle_chunk, *args], extra_args
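
Inside the chunk generator, slicing each extra input is now a flatten of the spatial dimensions followed by indexing, exactly `v.flatten(2)[:, :, start_idx:end_idx]`. A runnable sketch of that operation on made-up sizes, assuming full-grid sampling so ray order matches pixel order:

import torch

batch, channels, height, width = 2, 1, 8, 8
v = torch.rand(batch, channels, height, width)  # one chunked input, e.g. a mask

chunk_size = 16
flat = v.flatten(2)  # (B, C, H*W): pixels in row-major order, like the rays
chunks = [
    flat[:, :, start : start + chunk_size]
    for start in range(0, flat.shape[2], chunk_size)
]
assert all(c.shape == (batch, channels, chunk_size) for c in chunks)
assert torch.equal(torch.cat(chunks, dim=2), flat)  # chunks tile the image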

@@ -9,13 +9,13 @@ import unittest
 import torch
 from pytorch3d.implicitron.models.base import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
-from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
 
 
 class TestGenericModel(unittest.TestCase):
     def test_gm(self):
-        # Simple test of a forward pass of the default GenericModel.
+        # Simple test of a forward and backward pass of the default GenericModel.
         device = torch.device("cuda:1")
         expand_args_fields(GenericModel)
         model = GenericModel()
@@ -51,6 +51,7 @@ class TestGenericModel(unittest.TestCase):
             **defaulted_args,
         )
         self.assertGreater(train_preds["objective"].item(), 0)
+        train_preds["objective"].backward()
 
         model.eval()
         with torch.no_grad():
@@ -65,3 +66,41 @@ class TestGenericModel(unittest.TestCase):
             eval_preds["images_render"].shape,
             (1, 3, model.render_image_height, model.render_image_width),
         )
+
+    def test_idr(self):
+        # Forward pass of GenericModel with IDR.
+        device = torch.device("cuda:1")
+        args = get_default_args(GenericModel)
+        args.renderer_class_type = "SignedDistanceFunctionRenderer"
+        args.implicit_function_class_type = "IdrFeatureField"
+        args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
+
+        model = GenericModel(**args)
+        model.to(device)
+
+        n_train_cameras = 2
+        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
+        cameras = PerspectiveCameras(R=R, T=T, device=device)
+
+        defaulted_args = {
+            "depth_map": None,
+            "mask_crop": None,
+            "sequence_name": None,
+        }
+
+        target_image_rgb = torch.rand(
+            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
+            device=device,
+        )
+        fg_probability = torch.rand(
+            (n_train_cameras, 1, model.render_image_height, model.render_image_width),
+            device=device,
+        )
+        train_preds = model(
+            camera=cameras,
+            evaluation_mode=EvaluationMode.TRAINING,
+            image_rgb=target_image_rgb,
+            fg_probability=fg_probability,
+            **defaulted_args,
+        )
+        self.assertGreater(train_preds["objective"].item(), 0)
|
Loading…
x
Reference in New Issue
Block a user