diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index 34901670..a92cb8f1 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -121,6 +121,19 @@ class ImplicitronRayBundle: else: self._lengths = value + def float_(self) -> None: + """Moves the tensors to float dtype in place + (helpful for mixed-precision tensors). + """ + self.origins = self.origins.float() + self.directions = self.directions.float() + self._lengths = self._lengths.float() if self._lengths is not None else None + self.xys = self.xys.float() + self.bins = self.bins.float() if self.bins is not None else None + self.pixel_radii_2d = ( + self.pixel_radii_2d.float() if self.pixel_radii_2d is not None else None + ) + def is_packed(self) -> bool: """ Returns whether the ImplicitronRayBundle carries data in packed state