From 5247f6ad7475093dcbde31354a6fab8ae0a5d15d Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Mon, 6 Jan 2025 04:17:57 -0800 Subject: [PATCH] Fixing type hints in FrameData Summary: As subj Reviewed By: bottler Differential Revision: D67791200 fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f --- pytorch3d/implicitron/dataset/frame_data.py | 31 +++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index 9cd7e75f..3d4d6167 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -26,7 +26,7 @@ from typing import ( import numpy as np import torch -from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset import orm_types, types from pytorch3d.implicitron.dataset.utils import ( adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_, @@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds +FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation +SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation + @dataclass class FrameData(Mapping[str, Any]): @@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC): @abstractmethod def build( self, - frame_annotation: types.FrameAnnotation, - sequence_annotation: types.SequenceAnnotation, + frame_annotation: FrameAnnotationT, + sequence_annotation: SequenceAnnotationT, *, load_blobs: bool = True, **kwargs, @@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def build( self, - frame_annotation: types.FrameAnnotation, - sequence_annotation: types.SequenceAnnotation, + frame_annotation: FrameAnnotationT, + sequence_annotation: SequenceAnnotationT, *, load_blobs: bool = True, **kwargs, @@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): image_np, frame_annotation.image.size, frame_data.fg_probability ) + depth_annotation = frame_annotation.depth if ( load_blobs and self.load_depths - and frame_annotation.depth is not None - and frame_annotation.depth.path is not None + and depth_annotation is not None + and depth_annotation.path is not None ): ( frame_data.depth_map, @@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): return frame_data - def _load_fg_probability( - self, entry: types.FrameAnnotation - ) -> Tuple[np.ndarray, str]: - assert self.dataset_root is not None and entry.mask is not None - full_path = os.path.join(self.dataset_root, entry.mask.path) + def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]: + mask_annotation = entry.mask + assert self.dataset_root is not None and mask_annotation is not None + full_path = os.path.join(self.dataset_root, mask_annotation.path) fg_probability = load_mask(self._local_path(full_path)) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( @@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def _load_mask_depth( self, - entry: types.FrameAnnotation, + entry: FrameAnnotationT, fg_mask: Optional[np.ndarray], ) -> Tuple[torch.Tensor, str, torch.Tensor]: entry_depth = entry.depth @@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def _get_pytorch3d_camera( self, - entry: types.FrameAnnotation, + entry: FrameAnnotationT, ) -> PerspectiveCameras: entry_viewpoint = entry.viewpoint assert entry_viewpoint is not None