mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Fixing type hints in FrameData
Summary: As subj Reviewed By: bottler Differential Revision: D67791200 fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
This commit is contained in:
parent
e41aff47db
commit
5247f6ad74
@ -26,7 +26,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import types
|
from pytorch3d.implicitron.dataset import orm_types, types
|
||||||
from pytorch3d.implicitron.dataset.utils import (
|
from pytorch3d.implicitron.dataset.utils import (
|
||||||
adjust_camera_to_bbox_crop_,
|
adjust_camera_to_bbox_crop_,
|
||||||
adjust_camera_to_image_scale_,
|
adjust_camera_to_image_scale_,
|
||||||
@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
|||||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||||
|
|
||||||
|
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
||||||
|
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrameData(Mapping[str, Any]):
|
class FrameData(Mapping[str, Any]):
|
||||||
@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: FrameAnnotationT,
|
||||||
sequence_annotation: types.SequenceAnnotation,
|
sequence_annotation: SequenceAnnotationT,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: FrameAnnotationT,
|
||||||
sequence_annotation: types.SequenceAnnotation,
|
sequence_annotation: SequenceAnnotationT,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||||
)
|
)
|
||||||
|
|
||||||
|
depth_annotation = frame_annotation.depth
|
||||||
if (
|
if (
|
||||||
load_blobs
|
load_blobs
|
||||||
and self.load_depths
|
and self.load_depths
|
||||||
and frame_annotation.depth is not None
|
and depth_annotation is not None
|
||||||
and frame_annotation.depth.path is not None
|
and depth_annotation.path is not None
|
||||||
):
|
):
|
||||||
(
|
(
|
||||||
frame_data.depth_map,
|
frame_data.depth_map,
|
||||||
@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
def _load_fg_probability(
|
def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
|
||||||
self, entry: types.FrameAnnotation
|
mask_annotation = entry.mask
|
||||||
) -> Tuple[np.ndarray, str]:
|
assert self.dataset_root is not None and mask_annotation is not None
|
||||||
assert self.dataset_root is not None and entry.mask is not None
|
full_path = os.path.join(self.dataset_root, mask_annotation.path)
|
||||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
|
||||||
fg_probability = load_mask(self._local_path(full_path))
|
fg_probability = load_mask(self._local_path(full_path))
|
||||||
if fg_probability.shape[-2:] != entry.image.size:
|
if fg_probability.shape[-2:] != entry.image.size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def _load_mask_depth(
|
def _load_mask_depth(
|
||||||
self,
|
self,
|
||||||
entry: types.FrameAnnotation,
|
entry: FrameAnnotationT,
|
||||||
fg_mask: Optional[np.ndarray],
|
fg_mask: Optional[np.ndarray],
|
||||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||||
entry_depth = entry.depth
|
entry_depth = entry.depth
|
||||||
@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def _get_pytorch3d_camera(
|
def _get_pytorch3d_camera(
|
||||||
self,
|
self,
|
||||||
entry: types.FrameAnnotation,
|
entry: FrameAnnotationT,
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
entry_viewpoint = entry.viewpoint
|
entry_viewpoint = entry.viewpoint
|
||||||
assert entry_viewpoint is not None
|
assert entry_viewpoint is not None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user