object_mask only if required

Summary: New function to check if a renderer needs the object mask.

Reviewed By: davnov134

Differential Revision: D35254009

fbshipit-source-id: 4c99e8a1c0f6641d910eb32bfd6cfae9d3463d50
This commit is contained in:
Jeremy Reizenstein 2022-04-26 08:01:45 -07:00 committed by Facebook GitHub Bot
parent 2edb93d184
commit 9320100abc
3 changed files with 11 additions and 2 deletions

View File

@ -397,7 +397,7 @@ class GenericModel(Configurable, torch.nn.Module):
func.bind_args(**custom_args)
chunked_renderer_inputs = {}
if fg_probability is not None:
if fg_probability is not None and self.renderer.requires_object_mask():
sampled_fb_prob = rend_utils.ndc_grid_sample(
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
)

View File

@ -72,9 +72,15 @@ class BaseRenderer(ABC, ReplaceableBase):
Base class for all Renderer implementations.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def requires_object_mask(self) -> bool:
"""
Whether `forward` needs the object_mask.
"""
return False
@abstractmethod
def forward(
self,

View File

@ -49,6 +49,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
def requires_object_mask(self) -> bool:
return True
def forward(
self,
ray_bundle: RayBundle,