mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00

Extract BlobLoader class from JsonIndexDataset and move crop_by_bbox to FrameData

Summary: Extracted the blob loader, added documentation for blob_loader, and did some
refactoring on fields. For detailed steps and discussions see:
https://github.com/facebookresearch/pytorch3d/pull/1463
https://github.com/fairinternal/pixar_replay/pull/160

Reviewed By: bottler

Differential Revision: D44061728

fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51

parent c759fc560f
commit ebdbfde0ce
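
For downstream code, the visible effect of this refactor is the new home of FrameData and of the frame-data builder; a minimal sketch of the post-commit imports (module paths taken from the diff below):

    from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
    from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder
    from pytorch3d.implicitron.dataset.utils import GenericWorkaround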
					
@@ -18,8 +18,9 @@ from torch.utils.data import (
    Sampler,
)

from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap
from .frame_data import FrameData
from .scene_batch_sampler import SceneBatchSampler
from .utils import is_known_frame_scalar


@@ -5,217 +5,27 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from dataclasses import dataclass, field, fields
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    ClassVar,
 | 
			
		||||
    Dict,
 | 
			
		||||
    Iterable,
 | 
			
		||||
    Iterator,
 | 
			
		||||
    List,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FrameData(Mapping[str, Any]):
 | 
			
		||||
    """
 | 
			
		||||
    A type of the elements returned by indexing the dataset object.
 | 
			
		||||
    It can represent both individual frames and batches thereof;
 | 
			
		||||
    in this documentation, the sizes of tensors refer to single frames;
 | 
			
		||||
    add the first batch dimension for the collation result.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        frame_number: The number of the frame within its sequence.
 | 
			
		||||
            0-based continuous integers.
 | 
			
		||||
        sequence_name: The unique name of the frame's sequence.
 | 
			
		||||
        sequence_category: The object category of the sequence.
 | 
			
		||||
        frame_timestamp: The time elapsed since the start of a sequence in sec.
 | 
			
		||||
        image_size_hw: The size of the image in pixels; (height, width) tensor
 | 
			
		||||
                        of shape (2,).
 | 
			
		||||
        image_path: The qualified path to the loaded image (with dataset_root).
 | 
			
		||||
        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
 | 
			
		||||
            of the frame; elements are floats in [0, 1].
 | 
			
		||||
        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
 | 
			
		||||
            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
 | 
			
		||||
            are a result of zero-padding of the image after cropping around
 | 
			
		||||
            the object bounding box; elements are floats in {0.0, 1.0}.
 | 
			
		||||
        depth_path: The qualified path to the frame's depth map.
 | 
			
		||||
        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
 | 
			
		||||
            of the frame; values correspond to distances from the camera;
 | 
			
		||||
            use `depth_mask` and `mask_crop` to filter for valid pixels.
 | 
			
		||||
        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
 | 
			
		||||
            depth map that are valid for evaluation, they have been checked for
 | 
			
		||||
            consistency across views; elements are floats in {0.0, 1.0}.
 | 
			
		||||
        mask_path: A qualified path to the foreground probability mask.
 | 
			
		||||
        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
 | 
			
		||||
            pixels belonging to the captured object; elements are floats
 | 
			
		||||
            in [0, 1].
 | 
			
		||||
        bbox_xywh: The bounding box tightly enclosing the foreground object in the
 | 
			
		||||
            format (x0, y0, width, height). The convention assumes that
 | 
			
		||||
            `x0+width` and `y0+height` includes the boundary of the box.
 | 
			
		||||
            I.e., to slice out the corresponding crop from an image tensor `I`
 | 
			
		||||
            we execute `crop = I[..., y0:y0+height, x0:x0+width]`
 | 
			
		||||
        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
 | 
			
		||||
            in the original image coordinates in the format (x0, y0, width, height).
 | 
			
		||||
            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
 | 
			
		||||
            from `bbox_xywh` due to padding (which can happen e.g. due to
 | 
			
		||||
            setting `JsonIndexDataset.box_crop_context > 0`)
 | 
			
		||||
        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
 | 
			
		||||
            corrected for cropping if it happened.
 | 
			
		||||
        camera_quality_score: The score proportional to the confidence of the
 | 
			
		||||
            frame's camera estimation (the higher the more accurate).
 | 
			
		||||
        point_cloud_quality_score: The score proportional to the accuracy of the
 | 
			
		||||
            frame's sequence point cloud (the higher the more accurate).
 | 
			
		||||
        sequence_point_cloud_path: The path to the sequence's point cloud.
 | 
			
		||||
        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
 | 
			
		||||
            point cloud corresponding to the frame's sequence. When the object
 | 
			
		||||
            represents a batch of frames, point clouds may be deduplicated;
 | 
			
		||||
            see `sequence_point_cloud_idx`.
 | 
			
		||||
        sequence_point_cloud_idx: Integer indices mapping frame indices to the
 | 
			
		||||
            corresponding point clouds in `sequence_point_cloud`; to get the
 | 
			
		||||
            corresponding point cloud to `image_rgb[i]`, use
 | 
			
		||||
            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
 | 
			
		||||
        frame_type: The type of the loaded frame specified in
 | 
			
		||||
            `subset_lists_file`, if provided.
 | 
			
		||||
        meta: A dict for storing additional frame information.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_number: Optional[torch.LongTensor]
 | 
			
		||||
    sequence_name: Union[str, List[str]]
 | 
			
		||||
    sequence_category: Union[str, List[str]]
 | 
			
		||||
    frame_timestamp: Optional[torch.Tensor] = None
 | 
			
		||||
    image_size_hw: Optional[torch.Tensor] = None
 | 
			
		||||
    image_path: Union[str, List[str], None] = None
 | 
			
		||||
    image_rgb: Optional[torch.Tensor] = None
 | 
			
		||||
    # masks out padding added due to cropping the square bit
 | 
			
		||||
    mask_crop: Optional[torch.Tensor] = None
 | 
			
		||||
    depth_path: Union[str, List[str], None] = None
 | 
			
		||||
    depth_map: Optional[torch.Tensor] = None
 | 
			
		||||
    depth_mask: Optional[torch.Tensor] = None
 | 
			
		||||
    mask_path: Union[str, List[str], None] = None
 | 
			
		||||
    fg_probability: Optional[torch.Tensor] = None
 | 
			
		||||
    bbox_xywh: Optional[torch.Tensor] = None
 | 
			
		||||
    crop_bbox_xywh: Optional[torch.Tensor] = None
 | 
			
		||||
    camera: Optional[PerspectiveCameras] = None
 | 
			
		||||
    camera_quality_score: Optional[torch.Tensor] = None
 | 
			
		||||
    point_cloud_quality_score: Optional[torch.Tensor] = None
 | 
			
		||||
    sequence_point_cloud_path: Union[str, List[str], None] = None
 | 
			
		||||
    sequence_point_cloud: Optional[Pointclouds] = None
 | 
			
		||||
    sequence_point_cloud_idx: Optional[torch.Tensor] = None
 | 
			
		||||
    frame_type: Union[str, List[str], None] = None  # known | unseen
 | 
			
		||||
    meta: dict = field(default_factory=lambda: {})
 | 
			
		||||
 | 
			
		||||
    def to(self, *args, **kwargs):
 | 
			
		||||
        new_params = {}
 | 
			
		||||
        for f in fields(self):
 | 
			
		||||
            value = getattr(self, f.name)
 | 
			
		||||
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
 | 
			
		||||
                new_params[f.name] = value.to(*args, **kwargs)
 | 
			
		||||
            else:
 | 
			
		||||
                new_params[f.name] = value
 | 
			
		||||
        return type(self)(**new_params)
 | 
			
		||||
 | 
			
		||||
    def cpu(self):
 | 
			
		||||
        return self.to(device=torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    def cuda(self):
 | 
			
		||||
        return self.to(device=torch.device("cuda"))
 | 
			
		||||
 | 
			
		||||
    # the following functions make sure **frame_data can be passed to functions
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for f in fields(self):
 | 
			
		||||
            yield f.name
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, key):
 | 
			
		||||
        return getattr(self, key)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(fields(self))
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def collate(cls, batch):
 | 
			
		||||
        """
 | 
			
		||||
        Given a list objects `batch` of class `cls`, collates them into a batched
 | 
			
		||||
        representation suitable for processing with deep networks.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        elem = batch[0]
 | 
			
		||||
 | 
			
		||||
        if isinstance(elem, cls):
 | 
			
		||||
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
 | 
			
		||||
            id_to_idx = defaultdict(list)
 | 
			
		||||
            for i, pc_id in enumerate(pointcloud_ids):
 | 
			
		||||
                id_to_idx[pc_id].append(i)
 | 
			
		||||
 | 
			
		||||
            sequence_point_cloud = []
 | 
			
		||||
            sequence_point_cloud_idx = -np.ones((len(batch),))
 | 
			
		||||
            for i, ind in enumerate(id_to_idx.values()):
 | 
			
		||||
                sequence_point_cloud_idx[ind] = i
 | 
			
		||||
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
 | 
			
		||||
            assert (sequence_point_cloud_idx >= 0).all()
 | 
			
		||||
 | 
			
		||||
            override_fields = {
 | 
			
		||||
                "sequence_point_cloud": sequence_point_cloud,
 | 
			
		||||
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
 | 
			
		||||
            }
 | 
			
		||||
            # note that the pre-collate value of sequence_point_cloud_idx is unused
 | 
			
		||||
 | 
			
		||||
            collated = {}
 | 
			
		||||
            for f in fields(elem):
 | 
			
		||||
                list_values = override_fields.get(
 | 
			
		||||
                    f.name, [getattr(d, f.name) for d in batch]
 | 
			
		||||
                )
 | 
			
		||||
                collated[f.name] = (
 | 
			
		||||
                    cls.collate(list_values)
 | 
			
		||||
                    if all(list_value is not None for list_value in list_values)
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
            return cls(**collated)
 | 
			
		||||
 | 
			
		||||
        elif isinstance(elem, Pointclouds):
 | 
			
		||||
            return join_pointclouds_as_batch(batch)
 | 
			
		||||
 | 
			
		||||
        elif isinstance(elem, CamerasBase):
 | 
			
		||||
            # TODO: don't store K; enforce working in NDC space
 | 
			
		||||
            return join_cameras_as_batch(batch)
 | 
			
		||||
        else:
 | 
			
		||||
            return torch.utils.data._utils.collate.default_collate(batch)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _GenericWorkaround:
 | 
			
		||||
    """
 | 
			
		||||
    OmegaConf.structured has a weirdness when you try to apply
 | 
			
		||||
    it to a dataclass whose first base class is a Generic which is not
 | 
			
		||||
    Dict. The issue is with a function called get_dict_key_value_types
 | 
			
		||||
    in omegaconf/_utils.py.
 | 
			
		||||
    For example this fails:
 | 
			
		||||
 | 
			
		||||
        @dataclass(eq=False)
 | 
			
		||||
        class D(torch.utils.data.Dataset[int]):
 | 
			
		||||
            a: int = 3
 | 
			
		||||
 | 
			
		||||
        OmegaConf.structured(D)
 | 
			
		||||
 | 
			
		||||
    We avoid the problem by adding this class as an extra base class.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    pass
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.utils import GenericWorkaround


@dataclass(eq=False)
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
    """
    Base class to describe a dataset to be used with Implicitron.


pytorch3d/implicitron/dataset/frame_data.py (new file, 728 lines)
@@ -0,0 +1,728 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from dataclasses import dataclass, field, fields
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    ClassVar,
 | 
			
		||||
    Generic,
 | 
			
		||||
    List,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.dataset import types
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import (
 | 
			
		||||
    adjust_camera_to_bbox_crop_,
 | 
			
		||||
    adjust_camera_to_image_scale_,
 | 
			
		||||
    bbox_xyxy_to_xywh,
 | 
			
		||||
    clamp_box_to_image_bounds_and_round,
 | 
			
		||||
    crop_around_box,
 | 
			
		||||
    GenericWorkaround,
 | 
			
		||||
    get_bbox_from_mask,
 | 
			
		||||
    get_clamp_bbox,
 | 
			
		||||
    load_depth,
 | 
			
		||||
    load_depth_mask,
 | 
			
		||||
    load_image,
 | 
			
		||||
    load_mask,
 | 
			
		||||
    load_pointcloud,
 | 
			
		||||
    rescale_bbox,
 | 
			
		||||
    resize_image,
 | 
			
		||||
    safe_as_tensor,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FrameData(Mapping[str, Any]):
 | 
			
		||||
    """
 | 
			
		||||
    A type of the elements returned by indexing the dataset object.
 | 
			
		||||
    It can represent both individual frames and batches thereof;
 | 
			
		||||
    in this documentation, the sizes of tensors refer to single frames;
 | 
			
		||||
    add the first batch dimension for the collation result.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        frame_number: The number of the frame within its sequence.
 | 
			
		||||
            0-based continuous integers.
 | 
			
		||||
        sequence_name: The unique name of the frame's sequence.
 | 
			
		||||
        sequence_category: The object category of the sequence.
 | 
			
		||||
        frame_timestamp: The time elapsed since the start of a sequence in sec.
 | 
			
		||||
        image_size_hw: The size of the original image in pixels; (height, width)
 | 
			
		||||
            tensor of shape (2,). Note that it is optional, e.g. it can be `None`
 | 
			
		||||
            if the frame annotation has no size and image_rgb has not [yet] been
 | 
			
		||||
            loaded. Image-less FrameData is valid but mutators like crop/resize
 | 
			
		||||
            may fail if the original image size cannot be deduced.
 | 
			
		||||
        effective_image_size_hw: The size of the image after mutations such as
 | 
			
		||||
            crop/resize in pixels; (height, width). If the image has not been mutated,
 | 
			
		||||
            it is equal to `image_size_hw`. Note that it is also optional, for the
 | 
			
		||||
            same reason as `image_size_hw`.
 | 
			
		||||
        image_path: The qualified path to the loaded image (with dataset_root).
 | 
			
		||||
        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
 | 
			
		||||
            of the frame; elements are floats in [0, 1].
 | 
			
		||||
        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
 | 
			
		||||
            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
 | 
			
		||||
            are a result of zero-padding of the image after cropping around
 | 
			
		||||
            the object bounding box; elements are floats in {0.0, 1.0}.
 | 
			
		||||
        depth_path: The qualified path to the frame's depth map.
 | 
			
		||||
        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
 | 
			
		||||
            of the frame; values correspond to distances from the camera;
 | 
			
		||||
            use `depth_mask` and `mask_crop` to filter for valid pixels.
 | 
			
		||||
        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
 | 
			
		||||
            depth map that are valid for evaluation, they have been checked for
 | 
			
		||||
            consistency across views; elements are floats in {0.0, 1.0}.
 | 
			
		||||
        mask_path: A qualified path to the foreground probability mask.
 | 
			
		||||
        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
 | 
			
		||||
            pixels belonging to the captured object; elements are floats
 | 
			
		||||
            in [0, 1].
 | 
			
		||||
        bbox_xywh: The bounding box tightly enclosing the foreground object in the
 | 
			
		||||
            format (x0, y0, width, height). The convention assumes that
 | 
			
		||||
            `x0+width` and `y0+height` includes the boundary of the box.
 | 
			
		||||
            I.e., to slice out the corresponding crop from an image tensor `I`
 | 
			
		||||
            we execute `crop = I[..., y0:y0+height, x0:x0+width]`
 | 
			
		||||
        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
 | 
			
		||||
            in the original image coordinates in the format (x0, y0, width, height).
 | 
			
		||||
            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
 | 
			
		||||
            from `bbox_xywh` due to padding (which can happen e.g. due to
 | 
			
		||||
            setting `JsonIndexDataset.box_crop_context > 0`)
 | 
			
		||||
        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
 | 
			
		||||
            corrected for cropping if it happened.
 | 
			
		||||
        camera_quality_score: The score proportional to the confidence of the
 | 
			
		||||
            frame's camera estimation (the higher the more accurate).
 | 
			
		||||
        point_cloud_quality_score: The score proportional to the accuracy of the
 | 
			
		||||
            frame's sequence point cloud (the higher the more accurate).
 | 
			
		||||
        sequence_point_cloud_path: The path to the sequence's point cloud.
 | 
			
		||||
        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
 | 
			
		||||
            point cloud corresponding to the frame's sequence. When the object
 | 
			
		||||
            represents a batch of frames, point clouds may be deduplicated;
 | 
			
		||||
            see `sequence_point_cloud_idx`.
 | 
			
		||||
        sequence_point_cloud_idx: Integer indices mapping frame indices to the
 | 
			
		||||
            corresponding point clouds in `sequence_point_cloud`; to get the
 | 
			
		||||
            corresponding point cloud to `image_rgb[i]`, use
 | 
			
		||||
            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
 | 
			
		||||
        frame_type: The type of the loaded frame specified in
 | 
			
		||||
            `subset_lists_file`, if provided.
 | 
			
		||||
        meta: A dict for storing additional frame information.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_number: Optional[torch.LongTensor]
 | 
			
		||||
    sequence_name: Union[str, List[str]]
 | 
			
		||||
    sequence_category: Union[str, List[str]]
 | 
			
		||||
    frame_timestamp: Optional[torch.Tensor] = None
 | 
			
		||||
    image_size_hw: Optional[torch.LongTensor] = None
 | 
			
		||||
    effective_image_size_hw: Optional[torch.LongTensor] = None
 | 
			
		||||
    image_path: Union[str, List[str], None] = None
 | 
			
		||||
    image_rgb: Optional[torch.Tensor] = None
 | 
			
		||||
    # masks out padding added due to cropping the square bit
 | 
			
		||||
    mask_crop: Optional[torch.Tensor] = None
 | 
			
		||||
    depth_path: Union[str, List[str], None] = None
 | 
			
		||||
    depth_map: Optional[torch.Tensor] = None
 | 
			
		||||
    depth_mask: Optional[torch.Tensor] = None
 | 
			
		||||
    mask_path: Union[str, List[str], None] = None
 | 
			
		||||
    fg_probability: Optional[torch.Tensor] = None
 | 
			
		||||
    bbox_xywh: Optional[torch.Tensor] = None
 | 
			
		||||
    crop_bbox_xywh: Optional[torch.Tensor] = None
 | 
			
		||||
    camera: Optional[PerspectiveCameras] = None
 | 
			
		||||
    camera_quality_score: Optional[torch.Tensor] = None
 | 
			
		||||
    point_cloud_quality_score: Optional[torch.Tensor] = None
 | 
			
		||||
    sequence_point_cloud_path: Union[str, List[str], None] = None
 | 
			
		||||
    sequence_point_cloud: Optional[Pointclouds] = None
 | 
			
		||||
    sequence_point_cloud_idx: Optional[torch.Tensor] = None
 | 
			
		||||
    frame_type: Union[str, List[str], None] = None  # known | unseen
 | 
			
		||||
    meta: dict = field(default_factory=lambda: {})
 | 
			
		||||
 | 
			
		||||
    # NOTE that batching resets this attribute
 | 
			
		||||
    _uncropped: bool = field(init=False, default=True)
 | 
			
		||||
 | 
			
		||||
    def to(self, *args, **kwargs):
 | 
			
		||||
        new_params = {}
 | 
			
		||||
        for field_name in iter(self):
 | 
			
		||||
            value = getattr(self, field_name)
 | 
			
		||||
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
 | 
			
		||||
                new_params[field_name] = value.to(*args, **kwargs)
 | 
			
		||||
            else:
 | 
			
		||||
                new_params[field_name] = value
 | 
			
		||||
        frame_data = type(self)(**new_params)
 | 
			
		||||
        frame_data._uncropped = self._uncropped
 | 
			
		||||
        return frame_data
 | 
			
		||||
 | 
			
		||||
    def cpu(self):
 | 
			
		||||
        return self.to(device=torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    def cuda(self):
 | 
			
		||||
        return self.to(device=torch.device("cuda"))
 | 
			
		||||
 | 
			
		||||
    # the following functions make sure **frame_data can be passed to functions
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for f in fields(self):
 | 
			
		||||
            if f.name.startswith("_"):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            yield f.name
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, key):
 | 
			
		||||
        return getattr(self, key)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return sum(1 for f in iter(self))
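
# --- Illustrative sketch, not part of frame_data.py: FrameData behaves like a Mapping. ---
# Only the three fields without defaults are required; the values below are invented, and
# the import path assumes the post-commit layout introduced by this diff.
import torch
from pytorch3d.implicitron.dataset.frame_data import FrameData

example_frame = FrameData(
    frame_number=torch.tensor(0, dtype=torch.long),
    sequence_name="seq_000",
    sequence_category="apple",
    image_rgb=torch.rand(3, 8, 8),
)
example_frame = example_frame.cuda() if torch.cuda.is_available() else example_frame.cpu()
# Mapping protocol: len/iteration/** expansion skip private fields such as _uncropped.
print(len(example_frame), example_frame["sequence_name"])
example_kwargs = dict(**example_frame)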
 | 
			
		||||
 | 
			
		||||
    def crop_by_metadata_bbox_(
 | 
			
		||||
        self,
 | 
			
		||||
        box_crop_context: float,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Crops the frame data in-place by (possibly expanded) bounding box.
 | 
			
		||||
        The bounding box is taken from the object state (usually taken from
 | 
			
		||||
        the frame annotation or estimated from the foreground mask).
 | 
			
		||||
        If the expanded bounding box does not fit the image, it is clamped,
 | 
			
		||||
        i.e. the image is *not* padded.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            box_crop_context: rate of expansion for bbox; 0 means no expansion.
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: If the object does not contain a bounding box (usually when no
 | 
			
		||||
                mask annotation is provided)
 | 
			
		||||
            ValueError: If the frame data have been cropped or resized, thus the intrinsic
 | 
			
		||||
                bounding box is not valid for the current image size.
 | 
			
		||||
            ValueError: If the frame does not have an image size (usually a corner case
 | 
			
		||||
                when no image has been loaded)
 | 
			
		||||
        """
 | 
			
		||||
        if self.bbox_xywh is None:
 | 
			
		||||
            raise ValueError("Attempted cropping by metadata with empty bounding box")
 | 
			
		||||
 | 
			
		||||
        if not self._uncropped:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Trying to apply the metadata bounding box to already cropped "
 | 
			
		||||
                "or resized image; coordinates have changed."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self._crop_by_bbox_(
 | 
			
		||||
            box_crop_context,
 | 
			
		||||
            self.bbox_xywh,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def crop_by_given_bbox_(
 | 
			
		||||
        self,
 | 
			
		||||
        box_crop_context: float,
 | 
			
		||||
        bbox_xywh: torch.Tensor,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Crops the frame data in-place by (possibly expanded) bounding box.
 | 
			
		||||
        If the expanded bounding box does not fit the image, it is clamped,
 | 
			
		||||
        i.e. the image is *not* padded.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            box_crop_context: rate of expansion for bbox; 0 means no expansion.
 | 
			
		||||
            bbox_xywh: bounding box in [x0, y0, width, height] format. If float
 | 
			
		||||
                tensor, values are floored (after converting to [x0, y0, x1, y1]).
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: If the frame does not have an image size (usually a corner case
 | 
			
		||||
                when no image has been loaded)
 | 
			
		||||
        """
 | 
			
		||||
        self._crop_by_bbox_(
 | 
			
		||||
            box_crop_context,
 | 
			
		||||
            bbox_xywh,
 | 
			
		||||
        )
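
# --- Illustrative sketch, not part of frame_data.py: the bbox conventions in plain torch. ---
# The exact expansion/clamping rules live in get_clamp_bbox / clamp_box_to_image_bounds_and_round;
# this only illustrates the (x0, y0, w, h) convention and the crop slicing, with invented numbers
# and an assumed symmetric expansion.
import torch

example_image = torch.rand(3, 100, 200)                        # (C, H, W)
x0, y0, w, h = 40, 20, 60, 50                                  # bbox_xywh
ctx = 0.3                                                      # box_crop_context
x0e, y0e = x0 - w * ctx / 2, y0 - h * ctx / 2                  # expand around the box
x1e, y1e = x0 + w + w * ctx / 2, y0 + h + h * ctx / 2
x0c, y0c = max(0, int(round(x0e))), max(0, int(round(y0e)))    # clamp to image bounds and round
x1c, y1c = min(200, int(round(x1e))), min(100, int(round(y1e)))
example_crop = example_image[..., y0c:y1c, x0c:x1c]            # slicing convention from bbox_xywh docs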
 | 
			
		||||
 | 
			
		||||
    def _crop_by_bbox_(
 | 
			
		||||
        self,
 | 
			
		||||
        box_crop_context: float,
 | 
			
		||||
        bbox_xywh: torch.Tensor,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Crops the frame data in-place by (possibly expanded) bounding box.
 | 
			
		||||
        If the expanded bounding box does not fit the image, it is clamped,
 | 
			
		||||
        i.e. the image is *not* padded.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            box_crop_context: rate of expansion for bbox; 0 means no expansion.
 | 
			
		||||
            bbox_xywh: bounding box in [x0, y0, width, height] format. If float
 | 
			
		||||
                tensor, values are floored (after converting to [x0, y0, x1, y1]).
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: If the frame does not have an image size (usually a corner case
 | 
			
		||||
                when no image has been loaded)
 | 
			
		||||
        """
 | 
			
		||||
        effective_image_size_hw = self.effective_image_size_hw
 | 
			
		||||
        if effective_image_size_hw is None:
 | 
			
		||||
            raise ValueError("Calling crop on image-less FrameData")
 | 
			
		||||
 | 
			
		||||
        bbox_xyxy = get_clamp_bbox(
 | 
			
		||||
            bbox_xywh,
 | 
			
		||||
            image_path=self.image_path,  # pyre-ignore
 | 
			
		||||
            box_crop_context=box_crop_context,
 | 
			
		||||
        )
 | 
			
		||||
        clamp_bbox_xyxy = clamp_box_to_image_bounds_and_round(
 | 
			
		||||
            bbox_xyxy,
 | 
			
		||||
            image_size_hw=tuple(self.effective_image_size_hw),  # pyre-ignore
 | 
			
		||||
        )
 | 
			
		||||
        crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)
 | 
			
		||||
 | 
			
		||||
        if self.fg_probability is not None:
 | 
			
		||||
            self.fg_probability = crop_around_box(
 | 
			
		||||
                self.fg_probability,
 | 
			
		||||
                clamp_bbox_xyxy,
 | 
			
		||||
                self.mask_path,  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
        if self.image_rgb is not None:
 | 
			
		||||
            self.image_rgb = crop_around_box(
 | 
			
		||||
                self.image_rgb,
 | 
			
		||||
                clamp_bbox_xyxy,
 | 
			
		||||
                self.image_path,  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        depth_map = self.depth_map
 | 
			
		||||
        if depth_map is not None:
 | 
			
		||||
            clamp_bbox_xyxy_depth = rescale_bbox(
 | 
			
		||||
                clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw
 | 
			
		||||
            ).long()
 | 
			
		||||
            self.depth_map = crop_around_box(
 | 
			
		||||
                depth_map,
 | 
			
		||||
                clamp_bbox_xyxy_depth,
 | 
			
		||||
                self.depth_path,  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        depth_mask = self.depth_mask
 | 
			
		||||
        if depth_mask is not None:
 | 
			
		||||
            clamp_bbox_xyxy_depth = rescale_bbox(
 | 
			
		||||
                clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw
 | 
			
		||||
            ).long()
 | 
			
		||||
            self.depth_mask = crop_around_box(
 | 
			
		||||
                depth_mask,
 | 
			
		||||
                clamp_bbox_xyxy_depth,
 | 
			
		||||
                self.mask_path,  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # changing principal_point according to bbox_crop
 | 
			
		||||
        if self.camera is not None:
 | 
			
		||||
            adjust_camera_to_bbox_crop_(
 | 
			
		||||
                camera=self.camera,
 | 
			
		||||
                image_size_wh=effective_image_size_hw.flip(dims=[-1]),
 | 
			
		||||
                clamp_bbox_xywh=crop_bbox_xywh,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # pyre-ignore
 | 
			
		||||
        self.effective_image_size_hw = crop_bbox_xywh[..., 2:].flip(dims=[-1])
 | 
			
		||||
        self._uncropped = False
 | 
			
		||||
 | 
			
		||||
    def resize_frame_(self, new_size_hw: torch.LongTensor) -> None:
 | 
			
		||||
        """Resizes frame data in-place according to given dimensions.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            new_size_hw: target image size [height, width], a LongTensor of shape (2,)
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError: If the frame does not have an image size (usually a corner case
 | 
			
		||||
                when no image has been loaded)
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        effective_image_size_hw = self.effective_image_size_hw
 | 
			
		||||
        if effective_image_size_hw is None:
 | 
			
		||||
            raise ValueError("Calling resize on image-less FrameData")
 | 
			
		||||
 | 
			
		||||
        image_height, image_width = new_size_hw.tolist()
 | 
			
		||||
 | 
			
		||||
        if self.fg_probability is not None:
 | 
			
		||||
            self.fg_probability, _, _ = resize_image(
 | 
			
		||||
                self.fg_probability,
 | 
			
		||||
                image_height=image_height,
 | 
			
		||||
                image_width=image_width,
 | 
			
		||||
                mode="nearest",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.image_rgb is not None:
 | 
			
		||||
            self.image_rgb, _, self.mask_crop = resize_image(
 | 
			
		||||
                self.image_rgb, image_height=image_height, image_width=image_width
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.depth_map is not None:
 | 
			
		||||
            self.depth_map, _, _ = resize_image(
 | 
			
		||||
                self.depth_map,
 | 
			
		||||
                image_height=image_height,
 | 
			
		||||
                image_width=image_width,
 | 
			
		||||
                mode="nearest",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.depth_mask is not None:
 | 
			
		||||
            self.depth_mask, _, _ = resize_image(
 | 
			
		||||
                self.depth_mask,
 | 
			
		||||
                image_height=image_height,
 | 
			
		||||
                image_width=image_width,
 | 
			
		||||
                mode="nearest",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.camera is not None:
 | 
			
		||||
            if self.image_size_hw is None:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "image_size_hw has to be defined for resizing FrameData with cameras."
 | 
			
		||||
                )
 | 
			
		||||
            adjust_camera_to_image_scale_(
 | 
			
		||||
                camera=self.camera,
 | 
			
		||||
                original_size_wh=effective_image_size_hw.flip(dims=[-1]),
 | 
			
		||||
                new_size_wh=new_size_hw.flip(dims=[-1]),  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.effective_image_size_hw = new_size_hw
 | 
			
		||||
        self._uncropped = False
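
# --- Illustrative sketch, not part of frame_data.py: the aspect-preserving resize convention. ---
# resize_image is the actual helper; this only illustrates "scale to fit, paste into the
# top-left corner, return the valid-region mask" that mask_crop documents, with invented sizes.
import torch
import torch.nn.functional as F

def sketch_resize(image: torch.Tensor, out_h: int, out_w: int):
    c, h, w = image.shape
    scale = min(out_h / h, out_w / w)                          # preserve the aspect ratio
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    resized = F.interpolate(image[None], size=(new_h, new_w), mode="bilinear", align_corners=False)[0]
    canvas = image.new_zeros(c, out_h, out_w)                  # zero padding outside the pasted image
    canvas[:, :new_h, :new_w] = resized                        # paste into the top-left corner
    valid = torch.zeros(1, out_h, out_w)                       # cf. FrameData.mask_crop
    valid[:, :new_h, :new_w] = 1.0
    return canvas, scale, valid

example_resized, example_scale, example_mask = sketch_resize(torch.rand(3, 100, 200), 800, 800)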
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def collate(cls, batch):
 | 
			
		||||
        """
 | 
			
		||||
        Given a list objects `batch` of class `cls`, collates them into a batched
 | 
			
		||||
        representation suitable for processing with deep networks.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        elem = batch[0]
 | 
			
		||||
 | 
			
		||||
        if isinstance(elem, cls):
 | 
			
		||||
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
 | 
			
		||||
            id_to_idx = defaultdict(list)
 | 
			
		||||
            for i, pc_id in enumerate(pointcloud_ids):
 | 
			
		||||
                id_to_idx[pc_id].append(i)
 | 
			
		||||
 | 
			
		||||
            sequence_point_cloud = []
 | 
			
		||||
            sequence_point_cloud_idx = -np.ones((len(batch),))
 | 
			
		||||
            for i, ind in enumerate(id_to_idx.values()):
 | 
			
		||||
                sequence_point_cloud_idx[ind] = i
 | 
			
		||||
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
 | 
			
		||||
            assert (sequence_point_cloud_idx >= 0).all()
 | 
			
		||||
 | 
			
		||||
            override_fields = {
 | 
			
		||||
                "sequence_point_cloud": sequence_point_cloud,
 | 
			
		||||
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
 | 
			
		||||
            }
 | 
			
		||||
            # note that the pre-collate value of sequence_point_cloud_idx is unused
 | 
			
		||||
 | 
			
		||||
            collated = {}
 | 
			
		||||
            for f in fields(elem):
 | 
			
		||||
                if not f.init:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                list_values = override_fields.get(
 | 
			
		||||
                    f.name, [getattr(d, f.name) for d in batch]
 | 
			
		||||
                )
 | 
			
		||||
                collated[f.name] = (
 | 
			
		||||
                    cls.collate(list_values)
 | 
			
		||||
                    if all(list_value is not None for list_value in list_values)
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
            return cls(**collated)
 | 
			
		||||
 | 
			
		||||
        elif isinstance(elem, Pointclouds):
 | 
			
		||||
            return join_pointclouds_as_batch(batch)
 | 
			
		||||
 | 
			
		||||
        elif isinstance(elem, CamerasBase):
 | 
			
		||||
            # TODO: don't store K; enforce working in NDC space
 | 
			
		||||
            return join_cameras_as_batch(batch)
 | 
			
		||||
        else:
 | 
			
		||||
            return torch.utils.data._utils.collate.default_collate(batch)
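
# --- Illustrative sketch, not part of frame_data.py: collating FrameData objects. ---
# Invented values; the import path assumes the post-commit layout. FrameData.collate can also
# be handed to a torch DataLoader as collate_fn.
import torch
from pytorch3d.implicitron.dataset.frame_data import FrameData

example_batch = FrameData.collate(
    [
        FrameData(
            frame_number=torch.tensor(i, dtype=torch.long),
            sequence_name="seq_000",
            sequence_category="apple",
            image_rgb=torch.rand(3, 8, 8),
        )
        for i in range(2)
    ]
)
# example_batch.image_rgb has shape (2, 3, 8, 8); shared point clouds would be deduplicated
# and addressed through example_batch.sequence_point_cloud_idx.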
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
FrameDataSubtype = TypeVar("FrameDataSubtype", bound=FrameData)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
 | 
			
		||||
    """A base class for FrameDataBuilders that build a FrameData object, load and
 | 
			
		||||
    process the binary data (crop and resize). Implementations should parametrize
 | 
			
		||||
    the class with a subtype of FrameData and set frame_data_type class variable to
 | 
			
		||||
    that type. They also have to implement the `build` method.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # To be initialised to FrameDataSubtype
 | 
			
		||||
    frame_data_type: ClassVar[Type[FrameDataSubtype]]
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def build(
 | 
			
		||||
        self,
 | 
			
		||||
        frame_annotation: types.FrameAnnotation,
 | 
			
		||||
        sequence_annotation: types.SequenceAnnotation,
 | 
			
		||||
    ) -> FrameDataSubtype:
 | 
			
		||||
        """An abstract method to build the frame data based on raw frame/sequence
 | 
			
		||||
        annotations, load the binary data and adjust them according to the metadata.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
    """
 | 
			
		||||
    A class to build a FrameData object, load and process the binary data (crop and
 | 
			
		||||
    resize). This is an abstract class for extending to build FrameData subtypes. Most
 | 
			
		||||
    users need to use concrete `FrameDataBuilder` class instead.
 | 
			
		||||
    Beware that modifications of frame data are done in-place.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        dataset_root: The root folder of the dataset; all the paths in jsons are
 | 
			
		||||
                specified relative to this root (but not json paths themselves).
 | 
			
		||||
        load_images: Enable loading the frame RGB data.
 | 
			
		||||
        load_depths: Enable loading the frame depth maps.
 | 
			
		||||
        load_depth_masks: Enable loading the frame depth map masks denoting the
 | 
			
		||||
            depth values used for evaluation (the points consistent across views).
 | 
			
		||||
        load_masks: Enable loading frame foreground masks.
 | 
			
		||||
        load_point_clouds: Enable loading sequence-level point clouds.
 | 
			
		||||
        max_points: Cap on the number of loaded points in the point cloud;
 | 
			
		||||
            if reached, they are randomly sampled without replacement.
 | 
			
		||||
        mask_images: Whether to mask the images with the loaded foreground masks;
 | 
			
		||||
            0 value is used for background.
 | 
			
		||||
        mask_depths: Whether to mask the depth maps with the loaded foreground
 | 
			
		||||
            masks; 0 value is used for background.
 | 
			
		||||
        image_height: The height of the returned images, masks, and depth maps;
 | 
			
		||||
            aspect ratio is preserved during cropping/resizing.
 | 
			
		||||
        image_width: The width of the returned images, masks, and depth maps;
 | 
			
		||||
            aspect ratio is preserved during cropping/resizing.
 | 
			
		||||
        box_crop: Enable cropping of the image around the bounding box inferred
 | 
			
		||||
            from the foreground region of the loaded segmentation mask; masks
 | 
			
		||||
            and depth maps are cropped accordingly; cameras are corrected.
 | 
			
		||||
        box_crop_mask_thr: The threshold used to separate pixels into foreground
 | 
			
		||||
            and background based on the foreground_probability mask; if no value
 | 
			
		||||
            is greater than this threshold, the loader lowers it and repeats.
 | 
			
		||||
        box_crop_context: The amount of additional padding added to each
 | 
			
		||||
            dimension of the cropping bounding box, relative to box size.
 | 
			
		||||
        path_manager: Optionally a PathManager for interpreting paths in a special way.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    dataset_root: str = ""
 | 
			
		||||
    load_images: bool = True
 | 
			
		||||
    load_depths: bool = True
 | 
			
		||||
    load_depth_masks: bool = True
 | 
			
		||||
    load_masks: bool = True
 | 
			
		||||
    load_point_clouds: bool = False
 | 
			
		||||
    max_points: int = 0
 | 
			
		||||
    mask_images: bool = False
 | 
			
		||||
    mask_depths: bool = False
 | 
			
		||||
    image_height: Optional[int] = 800
 | 
			
		||||
    image_width: Optional[int] = 800
 | 
			
		||||
    box_crop: bool = True
 | 
			
		||||
    box_crop_mask_thr: float = 0.4
 | 
			
		||||
    box_crop_context: float = 0.3
 | 
			
		||||
    path_manager: Any = None
 | 
			
		||||
 | 
			
		||||
    def build(
 | 
			
		||||
        self,
 | 
			
		||||
        frame_annotation: types.FrameAnnotation,
 | 
			
		||||
        sequence_annotation: types.SequenceAnnotation,
 | 
			
		||||
        load_blobs: bool = True,
 | 
			
		||||
    ) -> FrameDataSubtype:
 | 
			
		||||
        """Builds the frame data based on raw frame/sequence annotations, loads the
 | 
			
		||||
        binary data and adjusts them according to the metadata. The processing includes:
 | 
			
		||||
            * if box_crop is set, the image/mask/depth are cropped with the bounding
 | 
			
		||||
                box provided or estimated from MaskAnnotation,
 | 
			
		||||
            * if image_height/image_width are set, the image/mask/depth are resized to
 | 
			
		||||
                fit that resolution. Note that the aspect ratio is preserved, and the
 | 
			
		||||
                (possibly cropped) image is pasted into the top-left corner. In the
 | 
			
		||||
                resulting frame_data, mask_crop field corresponds to the mask of the
 | 
			
		||||
                pasted image.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            frame_annotation: frame annotation
 | 
			
		||||
            sequence_annotation: sequence annotation
 | 
			
		||||
            load_blobs: if the function should attempt loading the image, depth map
 | 
			
		||||
                and mask, and foreground mask
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The constructed FrameData object.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        point_cloud = sequence_annotation.point_cloud
 | 
			
		||||
 | 
			
		||||
        frame_data = self.frame_data_type(
 | 
			
		||||
            frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long),
 | 
			
		||||
            frame_timestamp=safe_as_tensor(
 | 
			
		||||
                frame_annotation.frame_timestamp, torch.float
 | 
			
		||||
            ),
 | 
			
		||||
            sequence_name=frame_annotation.sequence_name,
 | 
			
		||||
            sequence_category=sequence_annotation.category,
 | 
			
		||||
            camera_quality_score=safe_as_tensor(
 | 
			
		||||
                sequence_annotation.viewpoint_quality_score, torch.float
 | 
			
		||||
            ),
 | 
			
		||||
            point_cloud_quality_score=safe_as_tensor(
 | 
			
		||||
                point_cloud.quality_score, torch.float
 | 
			
		||||
            )
 | 
			
		||||
            if point_cloud is not None
 | 
			
		||||
            else None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if load_blobs and self.load_masks and frame_annotation.mask is not None:
 | 
			
		||||
            (
 | 
			
		||||
                frame_data.fg_probability,
 | 
			
		||||
                frame_data.mask_path,
 | 
			
		||||
                frame_data.bbox_xywh,
 | 
			
		||||
            ) = self._load_fg_probability(frame_annotation)
 | 
			
		||||
 | 
			
		||||
        if frame_annotation.image is not None:
 | 
			
		||||
            image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
 | 
			
		||||
            frame_data.image_size_hw = image_size_hw  # original image size
 | 
			
		||||
            # image size after crop/resize
 | 
			
		||||
            frame_data.effective_image_size_hw = image_size_hw
 | 
			
		||||
 | 
			
		||||
            if load_blobs and self.load_images:
 | 
			
		||||
                (
 | 
			
		||||
                    frame_data.image_rgb,
 | 
			
		||||
                    frame_data.image_path,
 | 
			
		||||
                ) = self._load_images(frame_annotation, frame_data.fg_probability)
 | 
			
		||||
 | 
			
		||||
        if load_blobs and self.load_depths and frame_annotation.depth is not None:
 | 
			
		||||
            (
 | 
			
		||||
                frame_data.depth_map,
 | 
			
		||||
                frame_data.depth_path,
 | 
			
		||||
                frame_data.depth_mask,
 | 
			
		||||
            ) = self._load_mask_depth(frame_annotation, frame_data.fg_probability)
 | 
			
		||||
 | 
			
		||||
        if load_blobs and self.load_point_clouds and point_cloud is not None:
 | 
			
		||||
            pcl_path = self._fix_point_cloud_path(point_cloud.path)
 | 
			
		||||
            frame_data.sequence_point_cloud = load_pointcloud(
 | 
			
		||||
                self._local_path(pcl_path), max_points=self.max_points
 | 
			
		||||
            )
 | 
			
		||||
            frame_data.sequence_point_cloud_path = pcl_path
 | 
			
		||||
 | 
			
		||||
        if frame_annotation.viewpoint is not None:
 | 
			
		||||
            frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
 | 
			
		||||
 | 
			
		||||
        if self.box_crop:
 | 
			
		||||
            frame_data.crop_by_metadata_bbox_(self.box_crop_context)
 | 
			
		||||
 | 
			
		||||
        if self.image_height is not None and self.image_width is not None:
 | 
			
		||||
            new_size = (self.image_height, self.image_width)
 | 
			
		||||
            frame_data.resize_frame_(
 | 
			
		||||
                new_size_hw=torch.tensor(new_size, dtype=torch.long),  # pyre-ignore
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return frame_data
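
# --- Illustrative usage sketch, not part of frame_data.py. ---
# frame_annotation / sequence_annotation are assumed to be types.FrameAnnotation /
# types.SequenceAnnotation instances loaded from the dataset's json files; the path is a
# placeholder; FrameDataBuilder is the concrete builder registered at the end of this file,
# and expand_args_fields may be needed because the builder is a configurable.
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(FrameDataBuilder)
example_builder = FrameDataBuilder(
    dataset_root="/path/to/dataset",   # placeholder
    image_height=800,
    image_width=800,
    box_crop=True,
    box_crop_context=0.3,
)
example_frame_data = example_builder.build(
    frame_annotation, sequence_annotation, load_blobs=True
)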
 | 
			
		||||
 | 
			
		||||
    def _load_fg_probability(
 | 
			
		||||
        self, entry: types.FrameAnnotation
 | 
			
		||||
    ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
 | 
			
		||||
 | 
			
		||||
        full_path = os.path.join(self.dataset_root, entry.mask.path)  # pyre-ignore
 | 
			
		||||
        fg_probability = load_mask(self._local_path(full_path))
 | 
			
		||||
        # we can use provided bbox_xywh or calculate it based on mask
 | 
			
		||||
        # saves time to skip bbox calculation
 | 
			
		||||
        # pyre-ignore
 | 
			
		||||
        bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
 | 
			
		||||
            fg_probability, self.box_crop_mask_thr
 | 
			
		||||
        )
 | 
			
		||||
        if fg_probability.shape[-2:] != entry.image.size:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
 | 
			
		||||
            )
 | 
			
		||||
        return (
 | 
			
		||||
            safe_as_tensor(fg_probability, torch.float),
 | 
			
		||||
            full_path,
 | 
			
		||||
            safe_as_tensor(bbox_xywh, torch.long),
 | 
			
		||||
        )
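
# --- Illustrative sketch, not part of frame_data.py: deriving a box from a thresholded mask. ---
# get_bbox_from_mask is the actual helper (and it lowers the threshold if nothing survives);
# this minimal version only shows the idea, with an invented mask.
import torch

def sketch_bbox_from_mask(mask: torch.Tensor, thr: float):
    ys, xs = torch.nonzero(mask[0] > thr, as_tuple=True)      # mask: (1, H, W) probabilities
    if len(xs) == 0:
        return None                                            # real helper lowers thr and retries
    x0, y0 = int(xs.min()), int(ys.min())
    x1, y1 = int(xs.max()) + 1, int(ys.max()) + 1
    return (x0, y0, x1 - x0, y1 - y0)                          # (x0, y0, w, h)

example_bbox_xywh = sketch_bbox_from_mask(torch.rand(1, 100, 200), thr=0.4)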
 | 
			
		||||
 | 
			
		||||
    def _load_images(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor],
 | 
			
		||||
    ) -> Tuple[torch.Tensor, str]:
 | 
			
		||||
        assert self.dataset_root is not None and entry.image is not None
 | 
			
		||||
        path = os.path.join(self.dataset_root, entry.image.path)
 | 
			
		||||
        image_rgb = load_image(self._local_path(path))
 | 
			
		||||
 | 
			
		||||
        if image_rgb.shape[-2:] != entry.image.size:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.mask_images:
 | 
			
		||||
            assert fg_probability is not None
 | 
			
		||||
            image_rgb *= fg_probability
 | 
			
		||||
 | 
			
		||||
        return image_rgb, path
 | 
			
		||||
 | 
			
		||||
    def _load_mask_depth(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor],
 | 
			
		||||
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
 | 
			
		||||
        entry_depth = entry.depth
 | 
			
		||||
        assert entry_depth is not None
 | 
			
		||||
        path = os.path.join(self.dataset_root, entry_depth.path)
 | 
			
		||||
        depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
 | 
			
		||||
 | 
			
		||||
        if self.mask_depths:
 | 
			
		||||
            assert fg_probability is not None
 | 
			
		||||
            depth_map *= fg_probability
 | 
			
		||||
 | 
			
		||||
        if self.load_depth_masks:
 | 
			
		||||
            assert entry_depth.mask_path is not None
 | 
			
		||||
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
 | 
			
		||||
            depth_mask = load_depth_mask(self._local_path(mask_path))
 | 
			
		||||
        else:
 | 
			
		||||
            depth_mask = torch.ones_like(depth_map)
 | 
			
		||||
 | 
			
		||||
        return torch.tensor(depth_map), path, torch.tensor(depth_mask)
 | 
			
		||||
 | 
			
		||||
    def _get_pytorch3d_camera(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
    ) -> PerspectiveCameras:
 | 
			
		||||
        entry_viewpoint = entry.viewpoint
 | 
			
		||||
        assert entry_viewpoint is not None
 | 
			
		||||
        # principal point and focal length
 | 
			
		||||
        principal_point = torch.tensor(
 | 
			
		||||
            entry_viewpoint.principal_point, dtype=torch.float
 | 
			
		||||
        )
 | 
			
		||||
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
 | 
			
		||||
 | 
			
		||||
        format = entry_viewpoint.intrinsics_format
 | 
			
		||||
        if entry_viewpoint.intrinsics_format == "ndc_norm_image_bounds":
 | 
			
		||||
            # legacy PyTorch3D NDC format
 | 
			
		||||
            # convert to pixels unequally and convert to ndc equally
 | 
			
		||||
            image_size_as_list = list(reversed(entry.image.size))
 | 
			
		||||
            image_size_wh = torch.tensor(image_size_as_list, dtype=torch.float)
 | 
			
		||||
            per_axis_scale = image_size_wh / image_size_wh.min()
 | 
			
		||||
            focal_length = focal_length * per_axis_scale
 | 
			
		||||
            principal_point = principal_point * per_axis_scale
 | 
			
		||||
        elif entry_viewpoint.intrinsics_format != "ndc_isotropic":
 | 
			
		||||
            raise ValueError(f"Unknown intrinsics format: {format}")
 | 
			
		||||
 | 
			
		||||
        return PerspectiveCameras(
 | 
			
		||||
            focal_length=focal_length[None],
 | 
			
		||||
            principal_point=principal_point[None],
 | 
			
		||||
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
 | 
			
		||||
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
 | 
			
		||||
        )
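
# --- Illustrative sketch, not part of frame_data.py: the intrinsics rescaling above in numbers. ---
# "ndc_norm_image_bounds" normalises x and y by half the image width/height independently, while
# PyTorch3D's ndc_isotropic convention normalises both by half the shorter side; the conversion
# therefore multiplies by (W, H) / min(W, H). Values are invented.
import torch

example_size_wh = torch.tensor([200.0, 100.0])                 # (W, H)
example_focal = torch.tensor([1.2, 1.2])                       # ndc_norm_image_bounds values
example_pp = torch.tensor([0.05, -0.02])
per_axis = example_size_wh / example_size_wh.min()             # (2.0, 1.0) here
focal_ndc_isotropic = example_focal * per_axis
principal_point_ndc_isotropic = example_pp * per_axis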
 | 
			
		||||
 | 
			
		||||
    def _fix_point_cloud_path(self, path: str) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Fix up a point cloud path from the dataset.
 | 
			
		||||
        Some files in Co3Dv2 have an accidental absolute path stored.
 | 
			
		||||
        """
 | 
			
		||||
        unwanted_prefix = (
 | 
			
		||||
            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
 | 
			
		||||
        )
 | 
			
		||||
        if path.startswith(unwanted_prefix):
 | 
			
		||||
            path = path[len(unwanted_prefix) :]
 | 
			
		||||
        return os.path.join(self.dataset_root, path)
 | 
			
		||||
 | 
			
		||||
    def _local_path(self, path: str) -> str:
 | 
			
		||||
        if self.path_manager is None:
 | 
			
		||||
            return path
 | 
			
		||||
        return self.path_manager.get_local_path(path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
 | 
			
		||||
    """
 | 
			
		||||
    A concrete class to build a FrameData object, load and process the binary data (crop
 | 
			
		||||
    and resize). Beware that modifications of frame data are done in-place. Please see
 | 
			
		||||
    the documentation for `GenericFrameDataBuilder` for the description of parameters
 | 
			
		||||
    and methods.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_data_type: ClassVar[Type[FrameData]] = FrameData
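
# --- Illustrative sketch, not part of frame_data.py: a custom FrameData subtype and builder. ---
# Names are invented; as the FrameDataBuilderBase docstring asks, the builder is parametrised
# with the subtype and sets frame_data_type accordingly, mirroring FrameDataBuilder above.
from dataclasses import dataclass
from typing import ClassVar, List, Type, Union


@dataclass
class ExampleFrameData(FrameData):
    caption: Union[str, List[str], None] = None   # one extra, optional per-frame field


@registry.register
class ExampleFrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[ExampleFrameData]):
    frame_data_type: ClassVar[Type[ExampleFrameData]] = ExampleFrameData
    # build() is inherited; an override could call super().build(...) and then fill caption.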
 | 
			
		||||
@@ -15,7 +15,6 @@ import random
import warnings
from collections import defaultdict
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    ClassVar,
@@ -30,19 +29,15 @@ from typing import (
    Union,
)

import numpy as np
import torch
from PIL import Image
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
from tqdm import tqdm
from pytorch3d.renderer.cameras import CamerasBase

from . import types
from .dataset_base import DatasetBase, FrameData
from .utils import is_known_frame_scalar
from tqdm import tqdm


logger = logging.getLogger(__name__)
 | 
			
		||||
@ -65,7 +60,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
    A dataset with annotations in json files like the Common Objects in 3D
    (CO3D) dataset.

    Args:
    Metadata-related args::
        frame_annotations_file: A zipped json file containing metadata of the
            frames in the dataset, serialized List[types.FrameAnnotation].
        sequence_annotations_file: A zipped json file containing metadata of the
@ -83,6 +78,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        pick_sequence: A list of sequence names to restrict the dataset to.
        exclude_sequence: A list of the names of the sequences to exclude.
        limit_category_to: Restrict the dataset to the given list of categories.
        remove_empty_masks: Removes the frames with no active foreground pixels
            in the segmentation mask after thresholding (see box_crop_mask_thr).
        n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
            frames in each sequence uniformly without replacement if it has
            more frames than that; applied before other frame-level filters.
        seed: The seed of the random generator sampling #n_frames_per_sequence
            random frames per sequence.
        sort_frames: Enable frame annotations sorting to group frames from the
            same sequences together and order them by timestamps.
        eval_batches: A list of batches that form the evaluation set;
            list of batch-sized lists of indices corresponding to __getitem__
            of this class, thus it can be used directly as a batch sampler.
        eval_batch_index:
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
            A list of batches of frames described as (sequence_name, frame_idx)
            that can form the evaluation set, `eval_batches` will be set from this.

    Blob-loading parameters:
        dataset_root: The root folder of the dataset; all the paths in jsons are
            specified relative to this root (but not json paths themselves).
        load_images: Enable loading the frame RGB data.
@ -109,23 +122,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
            is greater than this threshold, the loader lowers it and repeats.
        box_crop_context: The amount of additional padding added to each
            dimension of the cropping bounding box, relative to box size.
        remove_empty_masks: Removes the frames with no active foreground pixels
            in the segmentation mask after thresholding (see box_crop_mask_thr).
        n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
            frames in each sequence uniformly without replacement if it has
            more frames than that; applied before other frame-level filters.
        seed: The seed of the random generator sampling #n_frames_per_sequence
            random frames per sequence.
        sort_frames: Enable frame annotations sorting to group frames from the
            same sequences together and order them by timestamps.
        eval_batches: A list of batches that form the evaluation set;
            list of batch-sized lists of indices corresponding to __getitem__
            of this class, thus it can be used directly as a batch sampler.
        eval_batch_index:
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
            A list of batches of frames described as (sequence_name, frame_idx)
            that can form the evaluation set, `eval_batches` will be set from this.

    """

    frame_annotations_type: ClassVar[
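An illustrative `eval_batch_index` value matching the documented type, with made-up sequence names; each inner list becomes one evaluation batch of (sequence_name, frame_idx) pairs:

    eval_batch_index = [
        [("seq_0017", 0), ("seq_0017", 10), ("seq_0017", 25)],
        [("seq_0042", 3)],
    ]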
@ -162,12 +158,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
    sort_frames: bool = False
    eval_batches: Any = None
    eval_batch_index: Any = None
    # initialised in __post_init__
    # commented because of OmegaConf (for tests to pass)
    # _frame_data_builder: FrameDataBuilder = field(init=False)
    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
    # _seq_to_idx: Dict[str, List[int]] = field(init=False)

    def __post_init__(self) -> None:
        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
        self.subset_to_image_path = None
        self._load_frames()
        self._load_sequences()
        if self.sort_frames:
@ -175,9 +173,27 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        self._load_subset_lists()
        self._filter_db()  # also computes sequence indices
        self._extract_and_set_eval_batches()

        # pyre-ignore
        self._frame_data_builder = FrameDataBuilder(
            dataset_root=self.dataset_root,
            load_images=self.load_images,
            load_depths=self.load_depths,
            load_depth_masks=self.load_depth_masks,
            load_masks=self.load_masks,
            load_point_clouds=self.load_point_clouds,
            max_points=self.max_points,
            mask_images=self.mask_images,
            mask_depths=self.mask_depths,
            image_height=self.image_height,
            image_width=self.image_width,
            box_crop=self.box_crop,
            box_crop_mask_thr=self.box_crop_mask_thr,
            box_crop_context=self.box_crop_context,
        )
        logger.info(str(self))

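A construction sketch under assumed paths (Implicitron normally fills these fields through its config system, so this direct call is illustrative only); the blob-loading arguments are exactly the ones forwarded to FrameDataBuilder above:

    dataset = JsonIndexDataset(
        frame_annotations_file="/path/to/co3d/teddybear/frame_annotations.jgz",
        sequence_annotations_file="/path/to/co3d/teddybear/sequence_annotations.jgz",
        dataset_root="/path/to/co3d",
        image_height=800,
        image_width=800,
        box_crop=True,
    )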
    def _extract_and_set_eval_batches(self):
    def _extract_and_set_eval_batches(self) -> None:
        """
        Sets eval_batches based on input eval_batch_index.
        """
@ -207,13 +223,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
            # https://gist.github.com/treyhunner/f35292e676efa0be1728
            functools.reduce(
                lambda a, b: {**a, **b},
                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
                # pyre-ignore[16]
                [d.seq_annots for d in other_datasets],
            )
        )
        all_eval_batches = [
            self.eval_batches,
            # pyre-ignore
            *[d.eval_batches for d in other_datasets],
            *[d.eval_batches for d in other_datasets],  # pyre-ignore[16]
        ]
        if not (
            all(ba is None for ba in all_eval_batches)
@ -251,7 +267,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        allow_missing_indices: bool = False,
        remove_missing_indices: bool = False,
        suppress_missing_index_warning: bool = True,
    ) -> List[List[Union[Optional[int], int]]]:
    ) -> Union[List[List[Optional[int]]], List[List[int]]]:
        """
        Obtain indices into the dataset object given a list of frame ids.

@ -323,9 +339,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
            valid_dataset_idx = [
                [b for b in batch if b is not None] for batch in dataset_idx
            ]
            return [  # pyre-ignore[7]
                batch for batch in valid_dataset_idx if len(batch) > 0
            ]
            return [batch for batch in valid_dataset_idx if len(batch) > 0]

        return dataset_idx

@ -417,255 +431,18 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
            raise IndexError(f"index {index} out of range {len(self.frame_annots)}")

        entry = self.frame_annots[index]["frame_annotation"]
        # pyre-ignore[16]
        point_cloud = self.seq_annots[entry.sequence_name].point_cloud
        frame_data = FrameData(
            frame_number=_safe_as_tensor(entry.frame_number, torch.long),
            frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
            sequence_name=entry.sequence_name,
            sequence_category=self.seq_annots[entry.sequence_name].category,
            camera_quality_score=_safe_as_tensor(
                self.seq_annots[entry.sequence_name].viewpoint_quality_score,
                torch.float,
            ),
            point_cloud_quality_score=_safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

        # The rest of the fields are optional
        # pyre-ignore
        frame_data = self._frame_data_builder.build(
            entry,
            # pyre-ignore
            self.seq_annots[entry.sequence_name],
        )
        # Optional field
        frame_data.frame_type = self._get_frame_type(self.frame_annots[index])

        (
            frame_data.fg_probability,
            frame_data.mask_path,
            frame_data.bbox_xywh,
            clamp_bbox_xyxy,
            frame_data.crop_bbox_xywh,
        ) = self._load_crop_fg_probability(entry)

        scale = 1.0
        if self.load_images and entry.image is not None:
            # original image size
            frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)

            (
                frame_data.image_rgb,
                frame_data.image_path,
                frame_data.mask_crop,
                scale,
            ) = self._load_crop_images(
                entry, frame_data.fg_probability, clamp_bbox_xyxy
            )

        if self.load_depths and entry.depth is not None:
            (
                frame_data.depth_map,
                frame_data.depth_path,
                frame_data.depth_mask,
            ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)

        if entry.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(
                entry,
                scale,
                clamp_bbox_xyxy,
            )

        if self.load_point_clouds and point_cloud is not None:
            pcl_path = self._fix_point_cloud_path(point_cloud.path)
            frame_data.sequence_point_cloud = _load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )
            frame_data.sequence_point_cloud_path = pcl_path

        return frame_data

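With blob loading delegated to the builder, indexing is unchanged; a small sketch using the dataset from the earlier example (index 0 purely illustrative):

    frame_data = dataset[0]  # a FrameData with image_rgb, camera, masks, etc., already cropped/resized
    print(frame_data.sequence_name, frame_data.frame_number)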
    def _fix_point_cloud_path(self, path: str) -> str:
        """
        Fix up a point cloud path from the dataset.
        Some files in Co3Dv2 have an accidental absolute path stored.
        """
        unwanted_prefix = (
            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
        )
        if path.startswith(unwanted_prefix):
            path = path[len(unwanted_prefix) :]
        return os.path.join(self.dataset_root, path)

    def _load_crop_fg_probability(
        self, entry: types.FrameAnnotation
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[str],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        fg_probability = None
        full_path = None
        bbox_xywh = None
        clamp_bbox_xyxy = None
        crop_box_xywh = None

        if (self.load_masks or self.box_crop) and entry.mask is not None:
            full_path = os.path.join(self.dataset_root, entry.mask.path)
            mask = _load_mask(self._local_path(full_path))

            if mask.shape[-2:] != entry.image.size:
                raise ValueError(
                    f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
                )

            bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))

            if self.box_crop:
                clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
                    _get_clamp_bbox(
                        bbox_xywh,
                        image_path=entry.image.path,
                        box_crop_context=self.box_crop_context,
                    ),
                    image_size_hw=tuple(mask.shape[-2:]),
                )
                crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)

                mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)

            fg_probability, _, _ = self._resize_image(mask, mode="nearest")

        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh

    def _load_crop_images(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
        assert self.dataset_root is not None and entry.image is not None
        path = os.path.join(self.dataset_root, entry.image.path)
        image_rgb = _load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)

        image_rgb, scale, mask_crop = self._resize_image(image_rgb)

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path, mask_crop, scale

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        clamp_bbox_xyxy: Optional[torch.Tensor],
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            depth_bbox_xyxy = _rescale_bbox(
                clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
            )
            depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)

        depth_map, _, _ = self._resize_image(depth_map, mode="nearest")

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = _load_depth_mask(self._local_path(mask_path))

            if self.box_crop:
                assert clamp_bbox_xyxy is not None
                depth_mask_bbox_xyxy = _rescale_bbox(
                    clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
                )
                depth_mask = _crop_around_box(
                    depth_mask, depth_mask_bbox_xyxy, mask_path
                )

            depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
        else:
            depth_mask = torch.ones_like(depth_map)

        return depth_map, path, depth_mask

    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
        scale: float,
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale
        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            principal_point_px -= clamp_bbox_xyxy[:2]

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        if self.image_height is None or self.image_width is None:
            out_size = list(reversed(entry.image.size))
        else:
            out_size = [self.image_width, self.image_height]

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )

    def _load_frames(self) -> None:
        logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
        local_file = self._local_path(self.frame_annotations_file)
@ -853,35 +630,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        # pyre-ignore[16]
        self._seq_to_idx = seq_to_idx

    def _resize_image(
        self, image, mode="bilinear"
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        image_height, image_width = self.image_height, self.image_width
        if image_height is None or image_width is None:
            # skip the resizing
            imre_ = torch.from_numpy(image)
            return imre_, 1.0, torch.ones_like(imre_[:1])
        # takes numpy array, returns pytorch tensor
        minscale = min(
            image_height / image.shape[-2],
            image_width / image.shape[-1],
        )
        imre = torch.nn.functional.interpolate(
            torch.from_numpy(image)[None],
            scale_factor=minscale,
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
            recompute_scale_factor=True,
        )[0]
        # pyre-fixme[19]: Expected 1 positional argument.
        imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
        imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
        # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
        # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
        mask = torch.zeros(1, self.image_height, self.image_width)
        mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
        return imre_, minscale, mask

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
@ -918,169 +666,3 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):

def _seq_name_to_seed(seq_name) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def _load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def _load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = _load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def _load_mask(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def _get_1d_bounds(arr) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1


def _get_bbox_from_mask(
    mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
    if thr <= 0.0:
        warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def _crop_around_box(tensor, bbox, impath: str = ""):
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = _clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tensor.shape[-2:],
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def _safe_as_tensor(data, dtype):
    if data is None:
        return None
    return torch.tensor(data, dtype=dtype)


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl

@ -20,8 +20,9 @@ from pytorch3d.implicitron.tools.config import (
)
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras

from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .frame_data import FrameData
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN

_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@ -69,7 +70,8 @@ class SingleSceneDataset(DatasetBase, Configurable):
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=pose,
            image_size_hw=torch.tensor(image.shape[1:]),
            # pyre-ignore
            image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=frame_type,

@ -55,6 +55,8 @@ class MaskAnnotation:
    path: str
    # (soft) number of pixels in the mask; sum(Prob(fg | pixel))
    mass: Optional[float] = None
    # tight bounding box around the foreground mask
    bounding_box_xywh: Optional[Tuple[float, float, float, float]] = None


@dataclass

@ -5,10 +5,18 @@
# LICENSE file in the root directory of this source tree.


from typing import List, Optional
import functools
import warnings
from pathlib import Path
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
from PIL import Image

from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds

DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
@ -16,6 +24,26 @@ DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"


class GenericWorkaround:
    """
    OmegaConf.structured has a weirdness when you try to apply
    it to a dataclass whose first base class is a Generic which is not
    Dict. The issue is with a function called get_dict_key_value_types
    in omegaconf/_utils.py.
    For example this fails:

        @dataclass(eq=False)
        class D(torch.utils.data.Dataset[int]):
            a: int = 3

        OmegaConf.structured(D)

    We avoid the problem by adding this class as an extra base class.
    """

    pass


def is_known_frame_scalar(frame_type: str) -> bool:
    """
    Given a single frame type corresponding to a single frame, return whether
@ -52,3 +80,286 @@ def is_train_frame(
        dtype=torch.bool,
        device=device,
    )


def get_bbox_from_mask(
    mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
    if thr <= 0.0:
        warnings.warn(
            f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
        )

    x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def crop_around_box(
    tensor: torch.Tensor, bbox: torch.Tensor, impath: str = ""
) -> torch.Tensor:
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tuple(tensor.shape[-2:]),
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


T = TypeVar("T", bound=torch.Tensor)


def bbox_xyxy_to_xywh(xyxy: T) -> T:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh  # pyre-ignore


def get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def rescale_bbox(
    bbox: torch.Tensor,
    orig_res: Union[Tuple[int, int], torch.LongTensor],
    new_res: Union[Tuple[int, int], torch.LongTensor],
) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    # pyre-ignore
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1


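A quick round trip through the public bbox helpers above, with illustrative values (image_tensor stands for any (C, H, W) tensor):

    bbox_xywh = torch.FloatTensor([10.0, 20.0, 100.0, 50.0])
    bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh)            # -> [10., 20., 110., 70.]
    assert torch.equal(bbox_xyxy_to_xywh(bbox_xyxy), bbox_xywh)
    cropped = crop_around_box(image_tensor, bbox_xyxy)  # clamps to image bounds, then slices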
def resize_image(
    image: Union[np.ndarray, torch.Tensor],
    image_height: Optional[int],
    image_width: Optional[int],
    mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:

    if type(image) == np.ndarray:
        image = torch.from_numpy(image)

    if image_height is None or image_width is None:
        # skip the resizing
        return image, 1.0, torch.ones_like(image[:1])
    # takes numpy array or tensor, returns pytorch tensor
    minscale = min(
        image_height / image.shape[-2],
        image_width / image.shape[-1],
    )
    imre = torch.nn.functional.interpolate(
        image[None],
        scale_factor=minscale,
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
        recompute_scale_factor=True,
    )[0]
    imre_ = torch.zeros(image.shape[0], image_height, image_width)
    imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
    mask = torch.zeros(1, image_height, image_width)
    mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
    return imre_, minscale, mask


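Resizing preserves the aspect ratio and zero-pads to the target shape, with mask_crop marking the valid region; a minimal sketch mirroring the unit test added further down in this diff:

    image = np.random.rand(3, 300, 500).astype(np.float32)
    resized, scale, mask_crop = resize_image(image, image_height=150, image_width=250)
    # scale == 0.5; resized and mask_crop both have spatial shape (150, 250)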
def load_image(path: str) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def load_mask(path: str) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def load_16big_png_depth(depth_png: str) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def safe_as_tensor(data, dtype):
    return torch.tensor(data, dtype=dtype) if data is not None else None


def _convert_ndc_to_pixels(
    focal_length: torch.Tensor,
    principal_point: torch.Tensor,
    image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    return focal_length_px, principal_point_px


def _convert_pixels_to_ndc(
    focal_length_px: torch.Tensor,
    principal_point_px: torch.Tensor,
    image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point = (half_image_size - principal_point_px) / rescale
    focal_length = focal_length_px / rescale
    return focal_length, principal_point


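The two converters are mutual inverses: with s = min(W, H) / 2, f_px = f_ndc * s and p_px = (W/2, H/2) - p_ndc * s. A round-trip check with illustrative numbers:

    image_size_wh = torch.tensor([500.0, 300.0])
    f_px, pp_px = _convert_ndc_to_pixels(
        torch.tensor([2.0, 2.0]), torch.tensor([0.1, -0.2]), image_size_wh
    )
    f_ndc, pp_ndc = _convert_pixels_to_ndc(f_px, pp_px, image_size_wh)
    # f_ndc == [2.0, 2.0] and pp_ndc == [0.1, -0.2] up to float rounding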
def adjust_camera_to_bbox_crop_(
    camera: PerspectiveCameras,
    image_size_wh: torch.Tensor,
    clamp_bbox_xywh: torch.Tensor,
) -> None:
    if len(camera) != 1:
        raise ValueError("Adjusting currently works with singleton cameras only")

    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
        camera.principal_point[0],  # pyre-ignore
        image_size_wh,
    )
    principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]

    focal_length, principal_point_cropped = _convert_pixels_to_ndc(
        focal_length_px,
        principal_point_px_cropped,
        clamp_bbox_xywh[2:],
    )

    camera.focal_length = focal_length[None]
    camera.principal_point = principal_point_cropped[None]  # pyre-ignore


def adjust_camera_to_image_scale_(
    camera: PerspectiveCameras,
    original_size_wh: torch.Tensor,
    new_size_wh: torch.LongTensor,
) -> PerspectiveCameras:
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
        camera.principal_point[0],  # pyre-ignore
        original_size_wh,
    )

    # now scale and convert from pixels to NDC
    image_size_wh_output = new_size_wh.float()
    scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
    focal_length_px_scaled = focal_length_px * scale
    principal_point_px_scaled = principal_point_px * scale

    focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
        focal_length_px_scaled,
        principal_point_px_scaled,
        image_size_wh_output,
    )
    camera.focal_length = focal_length_scaled[None]
    camera.principal_point = principal_point_scaled[None]  # pyre-ignore


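A sketch of adjusting a singleton camera for a crop followed by a resize; the values are illustrative, and in the refactored code path these helpers are presumably driven from the cropping logic moved into FrameData rather than called directly:

    camera = PerspectiveCameras(
        focal_length=torch.tensor([[2.0, 2.0]]),
        principal_point=torch.tensor([[0.0, 0.0]]),
    )
    image_size_wh = torch.tensor([500.0, 300.0])
    crop_xywh = torch.tensor([100.0, 50.0, 200.0, 200.0])
    adjust_camera_to_bbox_crop_(camera, image_size_wh, crop_xywh)
    adjust_camera_to_image_scale_(camera, crop_xywh[2:], torch.LongTensor([128, 128]))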
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl

@ -10,7 +10,7 @@ import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds

from .dataset_base import FrameData
from .frame_data import FrameData
from .json_index_dataset import JsonIndexDataset


@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Un
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
from pytorch3d.implicitron.tools import vis_utils

@ -17,7 +17,8 @@ from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DoublePoolBatchSampler,
)

from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler


@ -9,11 +9,19 @@ import unittest
import numpy as np

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import (
    _bbox_xywh_to_xyxy,
    _bbox_xyxy_to_xywh,
    _get_bbox_from_mask,

from pytorch3d.implicitron.dataset.utils import (
    bbox_xywh_to_xyxy,
    bbox_xyxy_to_xywh,
    clamp_box_to_image_bounds_and_round,
    crop_around_box,
    get_1d_bounds,
    get_bbox_from_mask,
    get_clamp_bbox,
    rescale_bbox,
    resize_image,
)

from tests.common_testing import TestCaseMixin

@ -31,9 +39,9 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        for bbox_xywh in bbox_xywh_list:
 | 
			
		||||
            bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
 | 
			
		||||
            bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
 | 
			
		||||
            bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
 | 
			
		||||
            bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh)
 | 
			
		||||
            bbox_xywh_ = bbox_xyxy_to_xywh(bbox_xyxy)
 | 
			
		||||
            bbox_xyxy_ = bbox_xywh_to_xyxy(bbox_xywh_)
 | 
			
		||||
            self.assertClose(bbox_xywh_, bbox_xywh)
 | 
			
		||||
            self.assertClose(bbox_xyxy, bbox_xyxy_)
 | 
			
		||||
 | 
			
		||||
@ -47,8 +55,8 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
 | 
			
		||||
            self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
 | 
			
		||||
            self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
 | 
			
		||||
            self.assertClose(bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
 | 
			
		||||
            self.assertClose(bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
 | 
			
		||||
 | 
			
		||||
        clamp_amnt = 3
 | 
			
		||||
        bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
 | 
			
		||||
@ -61,7 +69,7 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
 | 
			
		||||
                bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
 | 
			
		||||
                bbox_xyxy_expected,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -74,5 +82,61 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ]
 | 
			
		||||
        ).astype(np.float32)
 | 
			
		||||
        expected_bbox_xywh = [2, 1, 2, 1]
 | 
			
		||||
        bbox_xywh = _get_bbox_from_mask(mask, 0.5)
 | 
			
		||||
        bbox_xywh = get_bbox_from_mask(mask, 0.5)
 | 
			
		||||
        self.assertClose(bbox_xywh, expected_bbox_xywh)
 | 
			
		||||
 | 
			
		||||
    def test_crop_around_box(self):
 | 
			
		||||
        bbox = torch.LongTensor([0, 1, 2, 3])  # (x_min, y_min, x_max, y_max)
 | 
			
		||||
        image = torch.LongTensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0, 0, 10, 20],
 | 
			
		||||
                [10, 20, 5, 1],
 | 
			
		||||
                [10, 20, 1, 1],
 | 
			
		||||
                [5, 4, 0, 1],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        cropped = crop_around_box(image, bbox)
 | 
			
		||||
        self.assertClose(cropped, image[1:3, 0:2])
 | 
			
		||||
 | 
			
		||||
    def test_clamp_box_to_image_bounds_and_round(self):
 | 
			
		||||
        bbox = torch.LongTensor([0, 1, 10, 12])
 | 
			
		||||
        image_size = (5, 6)
 | 
			
		||||
        expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
 | 
			
		||||
        clamped_bbox = clamp_box_to_image_bounds_and_round(bbox, image_size)
 | 
			
		||||
        self.assertClose(clamped_bbox, expected_clamped_bbox)
 | 
			
		||||
 | 
			
		||||
    def test_get_clamp_bbox(self):
        bbox_xywh = torch.LongTensor([1, 1, 4, 5])
        clamped_bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=2)
        # box grown by box_crop_context: corners shifted out by context/2 of the
        # box size and the size scaled by (1 + context), then converted to xyxy
        self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))
    def test_rescale_bbox(self):
        bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
        original_resolution = (4, 4)
        new_resolution = (8, 8)  # twice as large
        rescaled_bbox = rescale_bbox(bbox, original_resolution, new_resolution)
        self.assertClose(bbox * 2, rescaled_bbox)
    def test_get_1d_bounds(self):
        array = [0, 1, 2]
        bounds = get_1d_bounds(array)
        # bounds of the nonzero part of the array: first nonzero index
        # and one past the last nonzero index
        self.assertClose(bounds, [1, 3])
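The asserted value [1, 3] suggests half-open bounds: the first nonzero index and one past the last. A tiny sketch under that assumption (get_1d_bounds_sketch is a hypothetical stand-in, not the library function):

    import numpy as np

    def get_1d_bounds_sketch(arr) -> tuple:
        # half-open interval [start, stop) covering the nonzero part of a 1-D array
        nz = np.flatnonzero(np.asarray(arr))
        return int(nz[0]), int(nz[-1]) + 1

    print(get_1d_bounds_sketch([0, 1, 2]))  # (1, 3), matching the assertion above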
    def test_resize_image(self):
        image = np.random.rand(3, 300, 500)  # rgb image 300x500
        expected_shape = (150, 250)

        resized_image, scale, mask_crop = resize_image(
            image, image_height=expected_shape[0], image_width=expected_shape[1]
        )

        original_shape = image.shape[-2:]
        expected_scale = min(
            expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
        )

        self.assertEqual(scale, expected_scale)
        self.assertEqual(resized_image.shape[-2:], expected_shape)
        self.assertEqual(mask_crop.shape[-2:], expected_shape)
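The assertions above imply that the resize keeps the aspect ratio with a single scale factor and pads to the requested size, with mask_crop marking the valid region. A sketch of that behaviour, assuming isotropic scaling plus zero-padding (the interpolation mode and padding details of the real resize_image may differ):

    import numpy as np
    import torch
    import torch.nn.functional as F

    def resize_image_sketch(image: np.ndarray, image_height: int, image_width: int):
        h, w = image.shape[-2:]
        scale = min(image_height / h, image_width / w)   # one scale for both axes
        img = F.interpolate(
            torch.from_numpy(image).float()[None], scale_factor=scale, mode="bilinear"
        )[0]
        out = torch.zeros(image.shape[0], image_height, image_width)
        out[:, : img.shape[1], : img.shape[2]] = img        # pad to the target size
        mask_crop = torch.zeros(1, image_height, image_width)
        mask_crop[:, : img.shape[1], : img.shape[2]] = 1.0  # valid (non-padded) pixels
        return out, scale, mask_crop

    resized, scale, mask = resize_image_sketch(
        np.random.rand(3, 300, 500).astype(np.float32), 150, 250
    )
    # scale == 0.5, resized and mask have spatial size (150, 250)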
@ -8,7 +8,7 @@ import os
import unittest

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
    RenderedMeshDatasetMapProvider,
)
@ -13,8 +13,10 @@ import os
import unittest

import lpips
import numpy as np
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData

from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
@ -268,7 +270,7 @@ class TestEvaluation(unittest.TestCase):
        for metric in lower_better:
            m_better = eval_result[metric]
            m_worse = eval_result_bad[metric]
            if m_better != m_better or m_worse != m_worse:
            if np.isnan(m_better) or np.isnan(m_worse):
                continue  # metric is missing, i.e. NaN
            _assert = (
                self.assertLessEqual
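The replaced condition relied on the fact that NaN is the only float that compares unequal to itself; np.isnan states the same check explicitly:

    import numpy as np

    x = float("nan")
    print(x != x)       # True only for NaN
    print(np.isnan(x))  # equivalent, but self-documenting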
							
								
								
									
 224  tests/implicitron/test_frame_data_builder.py  Normal file
@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gzip
import os
import unittest
from typing import List

import numpy as np
import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
    load_16big_png_depth,
    load_1bit_png_mask,
    load_depth,
    load_depth_mask,
    load_image,
    load_mask,
    safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras

from tests.common_testing import TestCaseMixin
from tests.implicitron.common_resources import get_skateboard_data
class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

        category = "skateboard"
        stack = contextlib.ExitStack()
        self.dataset_root, self.path_manager = stack.enter_context(
            get_skateboard_data()
        )
        self.addCleanup(stack.close)
        self.image_height = 768
        self.image_width = 512

        self.frame_data_builder = FrameDataBuilder(
            image_height=self.image_height,
            image_width=self.image_width,
            dataset_root=self.dataset_root,
            path_manager=self.path_manager,
        )

        # loading single frame annotation of dataset (see JsonIndexDataset._load_frames())
        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
        local_file = self.path_manager.get_local_path(frame_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            frame_annots_list = types.load_dataclass(
                zipfile, List[types.FrameAnnotation]
            )
            self.frame_annotation = frame_annots_list[0]

        sequence_annotations_file = os.path.join(
            self.dataset_root, category, "sequence_annotations.jgz"
        )
        local_file = self.path_manager.get_local_path(sequence_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            seq_annots_list = types.load_dataclass(
                zipfile, List[types.SequenceAnnotation]
            )
            seq_annots = {entry.sequence_name: entry for entry in seq_annots_list}
            self.seq_annotation = seq_annots[self.frame_annotation.sequence_name]

        point_cloud = self.seq_annotation.point_cloud
        self.frame_data = FrameData(
            frame_number=safe_as_tensor(self.frame_annotation.frame_number, torch.long),
            frame_timestamp=safe_as_tensor(
                self.frame_annotation.frame_timestamp, torch.float
            ),
            sequence_name=self.frame_annotation.sequence_name,
            sequence_category=self.seq_annotation.category,
            camera_quality_score=safe_as_tensor(
                self.seq_annotation.viewpoint_quality_score, torch.float
            ),
            point_cloud_quality_score=safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )
    def test_frame_data_builder_args(self):
        # test that FrameDataBuilder works with get_default_args
        get_default_args(FrameDataBuilder)
    def test_fix_point_cloud_path(self):
        """Some files in Co3Dv2 have an accidental absolute path stored."""
        original_path = "some_file_path"
        modified_path = self.frame_data_builder._fix_point_cloud_path(original_path)
        self.assertIn(original_path, modified_path)
        self.assertIn(self.frame_data_builder.dataset_root, modified_path)
    def test_load_and_adjust_frame_data(self):
        self.frame_data.image_size_hw = safe_as_tensor(
            self.frame_annotation.image.size, torch.long
        )
        self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw

        (
            self.frame_data.fg_probability,
            self.frame_data.mask_path,
            self.frame_data.bbox_xywh,
        ) = self.frame_data_builder._load_fg_probability(self.frame_annotation)

        self.assertIsNotNone(self.frame_data.mask_path)
        self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
        self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh))
        # assert bboxes shape
        self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))

        (
            self.frame_data.image_rgb,
            self.frame_data.image_path,
        ) = self.frame_data_builder._load_images(
            self.frame_annotation, self.frame_data.fg_probability
        )
        self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
        self.assertIsNotNone(self.frame_data.image_path)

        (
            self.frame_data.depth_map,
            depth_path,
            self.frame_data.depth_mask,
        ) = self.frame_data_builder._load_mask_depth(
            self.frame_annotation,
            self.frame_data.fg_probability,
        )
        self.assertTrue(torch.is_tensor(self.frame_data.depth_map))
        self.assertIsNotNone(depth_path)
        self.assertTrue(torch.is_tensor(self.frame_data.depth_mask))

        new_size = (self.image_height, self.image_width)

        if self.frame_data_builder.box_crop:
            self.frame_data.crop_by_metadata_bbox_(
                self.frame_data_builder.box_crop_context,
            )

        # assert image and mask shapes after resize
        self.frame_data.resize_frame_(
            new_size_hw=torch.tensor(new_size, dtype=torch.long),
        )
        self.assertEqual(
            self.frame_data.mask_crop.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.image_rgb.shape,
            torch.Size([3, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.mask_crop.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.fg_probability.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.depth_map.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.depth_mask.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.frame_data.camera = self.frame_data_builder._get_pytorch3d_camera(
            self.frame_annotation,
        )
        self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
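The order of operations exercised above: load the full-resolution blobs, optionally crop in place around the annotated box (expanded by box_crop_context), then resize in place to (image_height, image_width), after which image, masks and depth share the same spatial size. A condensed view of that sequence, with builder / frame_data as shorthand for the objects created in setUp() (illustrative, not an additional API):

    if builder.box_crop:
        frame_data.crop_by_metadata_bbox_(builder.box_crop_context)
    frame_data.resize_frame_(new_size_hw=torch.tensor((768, 512), dtype=torch.long))
    # image_rgb, fg_probability, mask_crop, depth_map, depth_mask -> H=768, W=512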
    def test_load_image(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
        local_path = self.path_manager.get_local_path(path)
        image = load_image(local_path)
        self.assertEqual(image.dtype, np.float32)
        self.assertLessEqual(np.max(image), 1.0)
        self.assertGreaterEqual(np.min(image), 0.0)
    def test_load_mask(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
        mask = load_mask(path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertLessEqual(np.max(mask), 1.0)
        self.assertGreaterEqual(np.min(mask), 0.0)
    def test_load_depth(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
        depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment)
        self.assertEqual(depth_map.dtype, np.float32)
        self.assertEqual(len(depth_map.shape), 3)
    def test_load_16big_png_depth(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
        depth_map = load_16big_png_depth(path)
        self.assertEqual(depth_map.dtype, np.float32)
        self.assertEqual(len(depth_map.shape), 2)
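The 2-D float32 result asserted above is consistent with a depth map stored as a 16-bit PNG whose raw uint16 payload encodes float16 values. A sketch of that decoding, offered as an assumption about the format rather than a restatement of load_16big_png_depth:

    import numpy as np
    from PIL import Image

    def load_png_depth_sketch(path: str) -> np.ndarray:
        with Image.open(path) as im:
            raw = np.array(im, dtype=np.uint16)          # (H, W) uint16 payload
        return raw.view(np.float16).astype(np.float32)   # reinterpret as float16 depth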
    def test_load_1bit_png_mask(self):
        mask_path = os.path.join(
            self.dataset_root, self.frame_annotation.depth.mask_path
        )
        mask = load_1bit_png_mask(mask_path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertEqual(len(mask.shape), 2)
    def test_load_depth_mask(self):
        mask_path = os.path.join(
            self.dataset_root, self.frame_annotation.depth.mask_path
        )
        mask = load_depth_mask(mask_path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertEqual(len(mask.shape), 3)
@ -17,7 +17,7 @@ import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)