From 9320100abc77f57ae47d2d3961cf039ccd5472d2 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 26 Apr 2022 08:01:45 -0700 Subject: [PATCH] object_mask only if required Summary: New function to check if a renderer needs the object mask. Reviewed By: davnov134 Differential Revision: D35254009 fbshipit-source-id: 4c99e8a1c0f6641d910eb32bfd6cfae9d3463d50 --- pytorch3d/implicitron/models/base.py | 2 +- pytorch3d/implicitron/models/renderer/base.py | 8 +++++++- pytorch3d/implicitron/models/renderer/sdf_renderer.py | 3 +++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pytorch3d/implicitron/models/base.py b/pytorch3d/implicitron/models/base.py index a2b303d3..4b5d6392 100644 --- a/pytorch3d/implicitron/models/base.py +++ b/pytorch3d/implicitron/models/base.py @@ -397,7 +397,7 @@ class GenericModel(Configurable, torch.nn.Module): func.bind_args(**custom_args) chunked_renderer_inputs = {} - if fg_probability is not None: + if fg_probability is not None and self.renderer.requires_object_mask(): sampled_fb_prob = rend_utils.ndc_grid_sample( fg_probability[:n_targets], ray_bundle.xys, mode="nearest" ) diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index f14d7231..3b5f2d83 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -72,9 +72,15 @@ class BaseRenderer(ABC, ReplaceableBase): Base class for all Renderer implementations. """ - def __init__(self): + def __init__(self) -> None: super().__init__() + def requires_object_mask(self) -> bool: + """ + Whether `forward` needs the object_mask. + """ + return False + @abstractmethod def forward( self, diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index e06f42d3..549f4cc8 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -49,6 +49,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False) + def requires_object_mask(self) -> bool: + return True + def forward( self, ray_bundle: RayBundle,