mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Fixing type hints in FrameData
Summary: As subj Reviewed By: bottler Differential Revision: D67791200 fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
This commit is contained in:
		
							parent
							
								
									e41aff47db
								
							
						
					
					
						commit
						5247f6ad74
					
				@ -26,7 +26,7 @@ from typing import (
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.dataset import types
 | 
			
		||||
from pytorch3d.implicitron.dataset import orm_types, types
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import (
 | 
			
		||||
    adjust_camera_to_bbox_crop_,
 | 
			
		||||
    adjust_camera_to_image_scale_,
 | 
			
		||||
@ -50,6 +50,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 | 
			
		||||
 | 
			
		||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
 | 
			
		||||
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FrameData(Mapping[str, Any]):
 | 
			
		||||
@ -454,8 +457,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def build(
 | 
			
		||||
        self,
 | 
			
		||||
        frame_annotation: types.FrameAnnotation,
 | 
			
		||||
        sequence_annotation: types.SequenceAnnotation,
 | 
			
		||||
        frame_annotation: FrameAnnotationT,
 | 
			
		||||
        sequence_annotation: SequenceAnnotationT,
 | 
			
		||||
        *,
 | 
			
		||||
        load_blobs: bool = True,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
@ -541,8 +544,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
 | 
			
		||||
    def build(
 | 
			
		||||
        self,
 | 
			
		||||
        frame_annotation: types.FrameAnnotation,
 | 
			
		||||
        sequence_annotation: types.SequenceAnnotation,
 | 
			
		||||
        frame_annotation: FrameAnnotationT,
 | 
			
		||||
        sequence_annotation: SequenceAnnotationT,
 | 
			
		||||
        *,
 | 
			
		||||
        load_blobs: bool = True,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
@ -620,11 +623,12 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
                    image_np, frame_annotation.image.size, frame_data.fg_probability
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        depth_annotation = frame_annotation.depth
 | 
			
		||||
        if (
 | 
			
		||||
            load_blobs
 | 
			
		||||
            and self.load_depths
 | 
			
		||||
            and frame_annotation.depth is not None
 | 
			
		||||
            and frame_annotation.depth.path is not None
 | 
			
		||||
            and depth_annotation is not None
 | 
			
		||||
            and depth_annotation.path is not None
 | 
			
		||||
        ):
 | 
			
		||||
            (
 | 
			
		||||
                frame_data.depth_map,
 | 
			
		||||
@ -653,11 +657,10 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
 | 
			
		||||
        return frame_data
 | 
			
		||||
 | 
			
		||||
    def _load_fg_probability(
 | 
			
		||||
        self, entry: types.FrameAnnotation
 | 
			
		||||
    ) -> Tuple[np.ndarray, str]:
 | 
			
		||||
        assert self.dataset_root is not None and entry.mask is not None
 | 
			
		||||
        full_path = os.path.join(self.dataset_root, entry.mask.path)
 | 
			
		||||
    def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
 | 
			
		||||
        mask_annotation = entry.mask
 | 
			
		||||
        assert self.dataset_root is not None and mask_annotation is not None
 | 
			
		||||
        full_path = os.path.join(self.dataset_root, mask_annotation.path)
 | 
			
		||||
        fg_probability = load_mask(self._local_path(full_path))
 | 
			
		||||
        if fg_probability.shape[-2:] != entry.image.size:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -685,7 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
 | 
			
		||||
    def _load_mask_depth(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
        entry: FrameAnnotationT,
 | 
			
		||||
        fg_mask: Optional[np.ndarray],
 | 
			
		||||
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
 | 
			
		||||
        entry_depth = entry.depth
 | 
			
		||||
@ -710,7 +713,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
 | 
			
		||||
    def _get_pytorch3d_camera(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
        entry: FrameAnnotationT,
 | 
			
		||||
    ) -> PerspectiveCameras:
 | 
			
		||||
        entry_viewpoint = entry.viewpoint
 | 
			
		||||
        assert entry_viewpoint is not None
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user