GM error for unbatched inputs

Summary: Error when sending an unbatched FrameData through GM.

Reviewed By: shapovalov

Differential Revision: D38036286

fbshipit-source-id: b8d280c61fbbefdc112c57ccd630ab3ccce7b44e
This commit is contained in:
Jeremy Reizenstein 2022-07-21 15:10:24 -07:00 committed by Facebook GitHub Bot
parent 3783437d2f
commit 54c75b4114

View File

@ -765,6 +765,17 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
Returns:
Modified image_rgb, fg_mask, depth_map
"""
if image_rgb is not None and image_rgb.ndim == 3:
# The FrameData object is used for both frames and batches of frames,
# and a user might get this error if those were confused.
# Perhaps a user has a FrameData `fd` representing a single frame and
# wrote something like `model(**fd)` instead of
# `model(**fd.collate([fd]))`.
raise ValueError(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated."
)
fg_mask = fg_probability
if fg_mask is not None and self.mask_threshold > 0.0:
# threshold masks