Fixing type hints in FrameData

Summary: As subj

Reviewed By: bottler

Differential Revision: D67791200

fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
This commit is contained in:
Roman Shapovalov 2025-01-06 04:17:57 -08:00 committed by Facebook GitHub Bot
parent e41aff47db
commit 5247f6ad74

View File

@ -26,7 +26,7 @@ from typing import (
import numpy as np
import torch
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset import orm_types, types
from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_,
@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
@dataclass
class FrameData(Mapping[str, Any]):
@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
@abstractmethod
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
*,
load_blobs: bool = True,
**kwargs,
@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
*,
load_blobs: bool = True,
**kwargs,
@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
image_np, frame_annotation.image.size, frame_data.fg_probability
)
depth_annotation = frame_annotation.depth
if (
load_blobs
and self.load_depths
and frame_annotation.depth is not None
and frame_annotation.depth.path is not None
and depth_annotation is not None
and depth_annotation.path is not None
):
(
frame_data.depth_map,
@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
return frame_data
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[np.ndarray, str]:
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
mask_annotation = entry.mask
assert self.dataset_root is not None and mask_annotation is not None
full_path = os.path.join(self.dataset_root, mask_annotation.path)
fg_probability = load_mask(self._local_path(full_path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
entry: FrameAnnotationT,
fg_mask: Optional[np.ndarray],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
entry: FrameAnnotationT,
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None