mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
object_mask only if required
Summary: New function to check if a renderer needs the object mask. Reviewed By: davnov134 Differential Revision: D35254009 fbshipit-source-id: 4c99e8a1c0f6641d910eb32bfd6cfae9d3463d50
This commit is contained in:
parent
2edb93d184
commit
9320100abc
@ -397,7 +397,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
func.bind_args(**custom_args)
|
||||
|
||||
chunked_renderer_inputs = {}
|
||||
if fg_probability is not None:
|
||||
if fg_probability is not None and self.renderer.requires_object_mask():
|
||||
sampled_fb_prob = rend_utils.ndc_grid_sample(
|
||||
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
|
||||
)
|
||||
|
@ -72,9 +72,15 @@ class BaseRenderer(ABC, ReplaceableBase):
|
||||
Base class for all Renderer implementations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def requires_object_mask(self) -> bool:
|
||||
"""
|
||||
Whether `forward` needs the object_mask.
|
||||
"""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
|
@ -49,6 +49,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
|
||||
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
||||
|
||||
def requires_object_mask(self) -> bool:
|
||||
return True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ray_bundle: RayBundle,
|
||||
|
Loading…
x
Reference in New Issue
Block a user