Fixing type hints in FrameData

Summary: As subj

Reviewed By: bottler

Differential Revision: D67791200

fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
This commit is contained in:
Roman Shapovalov 2025-01-06 04:17:57 -08:00 committed by Facebook GitHub Bot
parent e41aff47db
commit 5247f6ad74

View File

@ -26,7 +26,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
from pytorch3d.implicitron.dataset import types from pytorch3d.implicitron.dataset import orm_types, types
from pytorch3d.implicitron.dataset.utils import ( from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_, adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_, adjust_camera_to_image_scale_,
@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
@dataclass @dataclass
class FrameData(Mapping[str, Any]): class FrameData(Mapping[str, Any]):
@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
@abstractmethod @abstractmethod
def build( def build(
self, self,
frame_annotation: types.FrameAnnotation, frame_annotation: FrameAnnotationT,
sequence_annotation: types.SequenceAnnotation, sequence_annotation: SequenceAnnotationT,
*, *,
load_blobs: bool = True, load_blobs: bool = True,
**kwargs, **kwargs,
@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def build( def build(
self, self,
frame_annotation: types.FrameAnnotation, frame_annotation: FrameAnnotationT,
sequence_annotation: types.SequenceAnnotation, sequence_annotation: SequenceAnnotationT,
*, *,
load_blobs: bool = True, load_blobs: bool = True,
**kwargs, **kwargs,
@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
image_np, frame_annotation.image.size, frame_data.fg_probability image_np, frame_annotation.image.size, frame_data.fg_probability
) )
depth_annotation = frame_annotation.depth
if ( if (
load_blobs load_blobs
and self.load_depths and self.load_depths
and frame_annotation.depth is not None and depth_annotation is not None
and frame_annotation.depth.path is not None and depth_annotation.path is not None
): ):
( (
frame_data.depth_map, frame_data.depth_map,
@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
return frame_data return frame_data
def _load_fg_probability( def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
self, entry: types.FrameAnnotation mask_annotation = entry.mask
) -> Tuple[np.ndarray, str]: assert self.dataset_root is not None and mask_annotation is not None
assert self.dataset_root is not None and entry.mask is not None full_path = os.path.join(self.dataset_root, mask_annotation.path)
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path)) fg_probability = load_mask(self._local_path(full_path))
if fg_probability.shape[-2:] != entry.image.size: if fg_probability.shape[-2:] != entry.image.size:
raise ValueError( raise ValueError(
@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_mask_depth( def _load_mask_depth(
self, self,
entry: types.FrameAnnotation, entry: FrameAnnotationT,
fg_mask: Optional[np.ndarray], fg_mask: Optional[np.ndarray],
) -> Tuple[torch.Tensor, str, torch.Tensor]: ) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth entry_depth = entry.depth
@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _get_pytorch3d_camera( def _get_pytorch3d_camera(
self, self,
entry: types.FrameAnnotation, entry: FrameAnnotationT,
) -> PerspectiveCameras: ) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None assert entry_viewpoint is not None