diff --git a/pytorch3d/implicitron/models/base.py b/pytorch3d/implicitron/models/base.py index a2b303d3..4b5d6392 100644 --- a/pytorch3d/implicitron/models/base.py +++ b/pytorch3d/implicitron/models/base.py @@ -397,7 +397,7 @@ class GenericModel(Configurable, torch.nn.Module): func.bind_args(**custom_args) chunked_renderer_inputs = {} - if fg_probability is not None: + if fg_probability is not None and self.renderer.requires_object_mask(): sampled_fb_prob = rend_utils.ndc_grid_sample( fg_probability[:n_targets], ray_bundle.xys, mode="nearest" ) diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index f14d7231..3b5f2d83 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -72,9 +72,15 @@ class BaseRenderer(ABC, ReplaceableBase): Base class for all Renderer implementations. """ - def __init__(self): + def __init__(self) -> None: super().__init__() + def requires_object_mask(self) -> bool: + """ + Whether `forward` needs the object_mask. + """ + return False + @abstractmethod def forward( self, diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index e06f42d3..549f4cc8 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -49,6 +49,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False) + def requires_object_mask(self) -> bool: + return True + def forward( self, ray_bundle: RayBundle,