mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Fixing type hints in FrameData
Summary: As subj Reviewed By: bottler Differential Revision: D67791200 fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
This commit is contained in:
parent
e41aff47db
commit
5247f6ad74
@ -26,7 +26,7 @@ from typing import (
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset import orm_types, types
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
adjust_camera_to_bbox_crop_,
|
||||
adjust_camera_to_image_scale_,
|
||||
@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
||||
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||
)
|
||||
|
||||
depth_annotation = frame_annotation.depth
|
||||
if (
|
||||
load_blobs
|
||||
and self.load_depths
|
||||
and frame_annotation.depth is not None
|
||||
and frame_annotation.depth.path is not None
|
||||
and depth_annotation is not None
|
||||
and depth_annotation.path is not None
|
||||
):
|
||||
(
|
||||
frame_data.depth_map,
|
||||
@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
return frame_data
|
||||
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
assert self.dataset_root is not None and entry.mask is not None
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
|
||||
mask_annotation = entry.mask
|
||||
assert self.dataset_root is not None and mask_annotation is not None
|
||||
full_path = os.path.join(self.dataset_root, mask_annotation.path)
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _load_mask_depth(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
entry: FrameAnnotationT,
|
||||
fg_mask: Optional[np.ndarray],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _get_pytorch3d_camera(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
entry: FrameAnnotationT,
|
||||
) -> PerspectiveCameras:
|
||||
entry_viewpoint = entry.viewpoint
|
||||
assert entry_viewpoint is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user