From 54c75b41141668c54ae9a4f4a034548234661a9e Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 21 Jul 2022 15:10:24 -0700 Subject: [PATCH] GM error for unbatched inputs Summary: Error when sending an unbatched FrameData through GM. Reviewed By: shapovalov Differential Revision: D38036286 fbshipit-source-id: b8d280c61fbbefdc112c57ccd630ab3ccce7b44e --- pytorch3d/implicitron/models/generic_model.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 9952a67d..a0aa96d7 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -765,6 +765,17 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 Returns: Modified image_rgb, fg_mask, depth_map """ + if image_rgb is not None and image_rgb.ndim == 3: + # The FrameData object is used for both frames and batches of frames, + # and a user might get this error if those were confused. + # Perhaps a user has a FrameData `fd` representing a single frame and + # wrote something like `model(**fd)` instead of + # `model(**fd.collate([fd]))`. + raise ValueError( + "Model received unbatched inputs. " + + "Perhaps they came from a FrameData which had not been collated." + ) + fg_mask = fg_probability if fg_mask is not None and self.mask_threshold > 0.0: # threshold masks