# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
import gzip
import hashlib
import json
import logging
import os
import random
import warnings
from collections import defaultdict
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    ClassVar,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

import numpy as np
import torch
from PIL import Image
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
from tqdm import tqdm

from . import types
from .dataset_base import DatasetBase, FrameData
from .utils import is_known_frame_scalar


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from typing import TypedDict

    class FrameAnnotsEntry(TypedDict):
        subset: Optional[str]
        frame_annotation: types.FrameAnnotation

else:
    FrameAnnotsEntry = dict


@registry.register
class JsonIndexDataset(DatasetBase, ReplaceableBase):
    """
    A dataset with annotations in json files like the Common Objects in 3D
    (CO3D) dataset.

    Args:
        frame_annotations_file: A zipped json file containing metadata of the
            frames in the dataset, serialized List[types.FrameAnnotation].
        sequence_annotations_file: A zipped json file containing metadata of the
            sequences in the dataset, serialized List[types.SequenceAnnotation].
        subset_lists_file: A json file containing the lists of frames
            corresponding to different subsets (e.g. train/val/test) of the dataset;
            format: {subset: (sequence_name, frame_id, file_path)}.
        subsets: Restrict frames/sequences only to the given list of subsets
            as defined in subset_lists_file (see above).
        limit_to: Limit the dataset to the first #limit_to frames (after other
            filters have been applied).
        limit_sequences_to: Limit the dataset to the first
            #limit_sequences_to sequences (after other sequence filters have been
            applied but before frame-based filters).
        pick_sequence: A list of sequence names to restrict the dataset to.
        exclude_sequence: A list of the names of the sequences to exclude.
        limit_category_to: Restrict the dataset to the given list of categories.
        dataset_root: The root folder of the dataset; all the paths in jsons are
            specified relative to this root (but not json paths themselves).
        load_images: Enable loading the frame RGB data.
        load_depths: Enable loading the frame depth maps.
        load_depth_masks: Enable loading the frame depth map masks denoting the
            depth values used for evaluation (the points consistent across views).
        load_masks: Enable loading frame foreground masks.
        load_point_clouds: Enable loading sequence-level point clouds.
        max_points: Cap on the number of loaded points in the point cloud;
            if reached, they are randomly sampled without replacement.
        mask_images: Whether to mask the images with the loaded foreground masks;
            0 value is used for background.
        mask_depths: Whether to mask the depth maps with the loaded foreground
            masks; 0 value is used for background.
        image_height: The height of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        image_width: The width of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        box_crop: Enable cropping of the image around the bounding box inferred
            from the foreground region of the loaded segmentation mask; masks
            and depth maps are cropped accordingly; cameras are corrected.
        box_crop_mask_thr: The threshold used to separate pixels into foreground
            and background based on the foreground_probability mask; if no value
            is greater than this threshold, the loader lowers it and repeats.
        box_crop_context: The amount of additional padding added to each
            dimension of the cropping bounding box, relative to box size.
        remove_empty_masks: Removes the frames with no active foreground pixels
            in the segmentation mask after thresholding (see box_crop_mask_thr).
        n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
            frames in each sequence uniformly without replacement if it has
            more frames than that; applied before other frame-level filters.
        seed: The seed of the random generator sampling #n_frames_per_sequence
            random frames per sequence.
        sort_frames: Enable frame annotations sorting to group frames from the
            same sequences together and order them by timestamps.
        eval_batches: A list of batches that form the evaluation set;
            list of batch-sized lists of indices corresponding to __getitem__
            of this class, thus it can be used directly as a batch sampler.
        eval_batch_index:
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
            A list of batches of frames described as (sequence_name, frame_idx)
            that can form the evaluation set; `eval_batches` will be set from this.

    """

    frame_annotations_type: ClassVar[
        Type[types.FrameAnnotation]
    ] = types.FrameAnnotation

    path_manager: Any = None
    frame_annotations_file: str = ""
    sequence_annotations_file: str = ""
    subset_lists_file: str = ""
    subsets: Optional[List[str]] = None
    limit_to: int = 0
    limit_sequences_to: int = 0
    pick_sequence: Tuple[str, ...] = ()
    exclude_sequence: Tuple[str, ...] = ()
    limit_category_to: Tuple[int, ...] = ()
    dataset_root: str = ""
    load_images: bool = True
    load_depths: bool = True
    load_depth_masks: bool = True
    load_masks: bool = True
    load_point_clouds: bool = False
    max_points: int = 0
    mask_images: bool = False
    mask_depths: bool = False
    image_height: Optional[int] = 800
    image_width: Optional[int] = 800
    box_crop: bool = True
    box_crop_mask_thr: float = 0.4
    box_crop_context: float = 0.3
    remove_empty_masks: bool = True
    n_frames_per_sequence: int = -1
    seed: int = 0
    sort_frames: bool = False
    eval_batches: Any = None
    eval_batch_index: Any = None
    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)

    def __post_init__(self) -> None:
        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
        self.subset_to_image_path = None
        self._load_frames()
        self._load_sequences()
        if self.sort_frames:
            self._sort_frames()
        self._load_subset_lists()
        self._filter_db()  # also computes sequence indices
        self._extract_and_set_eval_batches()
        logger.info(str(self))

    def _extract_and_set_eval_batches(self):
        """
        Sets eval_batches based on input eval_batch_index.
        """
        if self.eval_batch_index is not None:
            if self.eval_batches is not None:
                raise ValueError(
                    "Cannot define both eval_batch_index and eval_batches."
                )
            self.eval_batches = self.seq_frame_index_to_dataset_index(
                self.eval_batch_index
            )

    def is_filtered(self):
        """
        Returns `True` in case the dataset has been filtered and thus some frame annotations
        stored on the disk might be missing in the dataset object.

        Returns:
            is_filtered: `True` if the dataset has been filtered, else `False`.
        """
        return (
            self.remove_empty_masks
            or self.limit_to > 0
            or self.limit_sequences_to > 0
            or len(self.pick_sequence) > 0
            or len(self.exclude_sequence) > 0
            or len(self.limit_category_to) > 0
            or self.n_frames_per_sequence > 0
        )

    def seq_frame_index_to_dataset_index(
        self,
        seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
        allow_missing_indices: bool = False,
        remove_missing_indices: bool = False,
    ) -> List[List[Union[Optional[int], int]]]:
        """
        Obtain indices into the dataset object given a list of frame ids.

        Args:
            seq_frame_index: The list of frame ids specified as
                `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
                image paths relative to the dataset_root can be specified as well:
                `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
                entry from `seq_frame_index` which is missing in the dataset.
                Otherwise, depending on `remove_missing_indices`, either returns `None`
                in place of missing entries or removes the indices of missing entries.
            remove_missing_indices: Active when `allow_missing_indices=True`.
                If `False`, returns `None` in place of `seq_frame_index` entries that
                are not present in the dataset.
                If `True`, removes missing indices from the returned indices.

        Returns:
            dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
        """
        _dataset_seq_frame_n_index = {
            seq: {
                # pyre-ignore[16]
                self.frame_annots[idx]["frame_annotation"].frame_number: idx
                for idx in seq_idx
            }
            # pyre-ignore[16]
            for seq, seq_idx in self._seq_to_idx.items()
        }

        def _get_dataset_idx(
            seq_name: str, frame_no: int, path: Optional[str] = None
        ) -> Optional[int]:
            idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
            idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
            if idx is None:
                msg = (
                    f"sequence_name={seq_name} / frame_number={frame_no}"
                    " not in the dataset!"
                )
                if not allow_missing_indices:
                    raise IndexError(msg)
                warnings.warn(msg)
                return idx
            if path is not None:
                # Check that the loaded frame path is consistent
                # with the one stored in self.frame_annots.
                assert os.path.normpath(
                    # pyre-ignore[16]
                    self.frame_annots[idx]["frame_annotation"].image.path
                ) == os.path.normpath(
                    path
                ), f"Inconsistent frame indices {seq_name, frame_no, path}."
            return idx

        dataset_idx = [
            [_get_dataset_idx(*b) for b in batch]  # pyre-ignore [6]
            for batch in seq_frame_index
        ]

        if allow_missing_indices and remove_missing_indices:
            # remove all None indices, and also batches with only None entries
            valid_dataset_idx = [
                [b for b in batch if b is not None] for batch in dataset_idx
            ]
            return [  # pyre-ignore[7]
                batch for batch in valid_dataset_idx if len(batch) > 0
            ]

        return dataset_idx
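
    # Illustrative example of the mapping above (sequence names and indices are
    # assumptions, not real data): given
    #   seq_frame_index = [[("apple_1", 10), ("apple_1", 12)], [("banana_3", 5)]]
    # the call
    #   dataset.seq_frame_index_to_dataset_index(seq_frame_index)
    # returns a list of batches of integer indices usable with dataset[i],
    # e.g. [[104, 106], [517]]; with allow_missing_indices=True, missing frames
    # become None (or are dropped when remove_missing_indices=True).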

    def subset_from_frame_index(
        self,
        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
        allow_missing_indices: bool = True,
    ) -> "JsonIndexDataset":
        # Get the indices into the frame annots.
        dataset_indices = self.seq_frame_index_to_dataset_index(
            [frame_index],
            allow_missing_indices=self.is_filtered() and allow_missing_indices,
        )[0]
        valid_dataset_indices = [i for i in dataset_indices if i is not None]

        # Deep copy the whole dataset except frame_annots, which are large so we
        # deep copy only the requested subset of frame_annots.
        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
        dataset_new = copy.deepcopy(self, memo)
        dataset_new.frame_annots = copy.deepcopy(
            [self.frame_annots[i] for i in valid_dataset_indices]
        )

        # This will kill all unneeded sequence annotations.
        dataset_new._invalidate_indexes(filter_seq_annots=True)

        # Finally annotate the frame annotations with the name of the subset
        # stored in meta.
        for frame_annot in dataset_new.frame_annots:
            frame_annotation = frame_annot["frame_annotation"]
            if frame_annotation.meta is not None:
                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)

        # A sanity check - this will crash in case some entries from frame_index are missing
        # in dataset_new.
        valid_frame_index = [
            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
        ]
        dataset_new.seq_frame_index_to_dataset_index(
            [valid_frame_index], allow_missing_indices=False
        )

        return dataset_new

    def __str__(self) -> str:
        # pyre-ignore[16]
        return f"JsonIndexDataset #frames={len(self.frame_annots)}"

    def __len__(self) -> int:
        # pyre-ignore[16]
        return len(self.frame_annots)

    def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
        return entry["subset"]

    def get_all_train_cameras(self) -> CamerasBase:
        """
        Returns the cameras corresponding to all the known frames.
        """
        logger.info("Loading all train cameras.")
        cameras = []
        # pyre-ignore[16]
        for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
            frame_type = self._get_frame_type(frame_annot)
            if frame_type is None:
                raise ValueError("subsets not loaded")
            if is_known_frame_scalar(frame_type):
                cameras.append(self[frame_idx].camera)
        return join_cameras_as_batch(cameras)
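
    # Illustrative usage sketch (assumed caller code):
    #   train_cameras = dataset.get_all_train_cameras()
    #   centers = train_cameras.get_camera_center()  # (N, 3) camera centers
    # This requires the subset lists to be loaded so that known frames can be
    # told apart from unseen ones.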

    def __getitem__(self, index) -> FrameData:
        # pyre-ignore[16]
        if index >= len(self.frame_annots):
            raise IndexError(f"index {index} out of range {len(self.frame_annots)}")

        entry = self.frame_annots[index]["frame_annotation"]
        # pyre-ignore[16]
        point_cloud = self.seq_annots[entry.sequence_name].point_cloud
        frame_data = FrameData(
            frame_number=_safe_as_tensor(entry.frame_number, torch.long),
            frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
            sequence_name=entry.sequence_name,
            sequence_category=self.seq_annots[entry.sequence_name].category,
            camera_quality_score=_safe_as_tensor(
                self.seq_annots[entry.sequence_name].viewpoint_quality_score,
                torch.float,
            ),
            point_cloud_quality_score=_safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

        # The rest of the fields are optional
        frame_data.frame_type = self._get_frame_type(self.frame_annots[index])

        (
            frame_data.fg_probability,
            frame_data.mask_path,
            frame_data.bbox_xywh,
            clamp_bbox_xyxy,
            frame_data.crop_bbox_xywh,
        ) = self._load_crop_fg_probability(entry)

        scale = 1.0
        if self.load_images and entry.image is not None:
            # original image size
            frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)

            (
                frame_data.image_rgb,
                frame_data.image_path,
                frame_data.mask_crop,
                scale,
            ) = self._load_crop_images(
                entry, frame_data.fg_probability, clamp_bbox_xyxy
            )

        if self.load_depths and entry.depth is not None:
            (
                frame_data.depth_map,
                frame_data.depth_path,
                frame_data.depth_mask,
            ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)

        if entry.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(
                entry,
                scale,
                clamp_bbox_xyxy,
            )

        if self.load_point_clouds and point_cloud is not None:
            pcl_path = self._fix_point_cloud_path(point_cloud.path)
            frame_data.sequence_point_cloud = _load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )
            frame_data.sequence_point_cloud_path = pcl_path

        return frame_data
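
    # Illustrative sketch of a returned FrameData (shapes are assumptions for
    # the default image_height=image_width=800):
    #   fd = dataset[0]
    #   fd.image_rgb.shape        # torch.Size([3, 800, 800]), cropped and resized
    #   fd.fg_probability.shape   # torch.Size([1, 800, 800]) if masks are loaded
    #   fd.camera                 # PerspectiveCameras aligned with the crop/resize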

    def _fix_point_cloud_path(self, path: str) -> str:
        """
        Fix up a point cloud path from the dataset.
        Some files in Co3Dv2 have an accidental absolute path stored.
        """
        unwanted_prefix = (
            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
        )
        if path.startswith(unwanted_prefix):
            path = path[len(unwanted_prefix) :]
        return os.path.join(self.dataset_root, path)

    def _load_crop_fg_probability(
        self, entry: types.FrameAnnotation
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[str],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        fg_probability = None
        full_path = None
        bbox_xywh = None
        clamp_bbox_xyxy = None
        crop_box_xywh = None

        if (self.load_masks or self.box_crop) and entry.mask is not None:
            full_path = os.path.join(self.dataset_root, entry.mask.path)
            mask = _load_mask(self._local_path(full_path))

            if mask.shape[-2:] != entry.image.size:
                raise ValueError(
                    f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
                )

            bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))

            if self.box_crop:
                clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
                    _get_clamp_bbox(
                        bbox_xywh,
                        image_path=entry.image.path,
                        box_crop_context=self.box_crop_context,
                    ),
                    image_size_hw=tuple(mask.shape[-2:]),
                )
                crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)

                mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)

            fg_probability, _, _ = self._resize_image(mask, mode="nearest")

        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh

    def _load_crop_images(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
        assert self.dataset_root is not None and entry.image is not None
        path = os.path.join(self.dataset_root, entry.image.path)
        image_rgb = _load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)

        image_rgb, scale, mask_crop = self._resize_image(image_rgb)

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path, mask_crop, scale

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        clamp_bbox_xyxy: Optional[torch.Tensor],
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            depth_bbox_xyxy = _rescale_bbox(
                clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
            )
            depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)

        depth_map, _, _ = self._resize_image(depth_map, mode="nearest")

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = _load_depth_mask(self._local_path(mask_path))

            if self.box_crop:
                assert clamp_bbox_xyxy is not None
                depth_mask_bbox_xyxy = _rescale_bbox(
                    clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
                )
                depth_mask = _crop_around_box(
                    depth_mask, depth_mask_bbox_xyxy, mask_path
                )

            depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
        else:
            depth_mask = torch.ones_like(depth_map)

        return depth_map, path, depth_mask

    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
        scale: float,
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale
        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            principal_point_px -= clamp_bbox_xyxy[:2]

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        if self.image_height is None or self.image_width is None:
            out_size = list(reversed(entry.image.size))
        else:
            out_size = [self.image_width, self.image_height]

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )
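
    # Worked example of the intrinsics conversion above (illustrative numbers,
    # assuming intrinsics_format="ndc_norm_image_bounds", no box crop and no
    # resize, i.e. scale == 1):
    #   original image (H, W) = (100, 200)  ->  half_image_size_wh_orig = [100, 50]
    #   principal_point (ndc) = [0, 0], focal_length (ndc) = [2, 4]
    #   principal_point_px = [100, 50] - [0, 0] * [100, 50] = [100, 50]
    #   focal_length_px    = [2, 4] * [100, 50]             = [200, 200]
    #   with out_size = [200, 100]: half_min_image_size_output = 50, so
    #   principal_point (PyTorch3D NDC) = ([100, 50] - [100, 50]) / 50 = [0, 0]
    #   focal_length (PyTorch3D NDC)    = [200, 200] / 50             = [4, 4]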

    def _load_frames(self) -> None:
        logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
        local_file = self._local_path(self.frame_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            frame_annots_list = types.load_dataclass(
                zipfile, List[self.frame_annotations_type]
            )
        if not frame_annots_list:
            raise ValueError("Empty dataset!")
        # pyre-ignore[16]
        self.frame_annots = [
            FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
        ]

    def _load_sequences(self) -> None:
        logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
        local_file = self._local_path(self.sequence_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
        if not seq_annots:
            raise ValueError("Empty sequences file!")
        # pyre-ignore[16]
        self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}

    def _load_subset_lists(self) -> None:
        logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
        if not self.subset_lists_file:
            return

        with open(self._local_path(self.subset_lists_file), "r") as f:
            subset_to_seq_frame = json.load(f)

        frame_path_to_subset = {
            path: subset
            for subset, frames in subset_to_seq_frame.items()
            for _, _, path in frames
        }
        # pyre-ignore[16]
        for frame in self.frame_annots:
            frame["subset"] = frame_path_to_subset.get(
                frame["frame_annotation"].image.path, None
            )
            if frame["subset"] is None:
                warnings.warn(
                    "Subset lists are given but don't include "
                    + frame["frame_annotation"].image.path
                )

    def _sort_frames(self) -> None:
        # Sort frames to have them grouped by sequence, ordered by timestamp
        # pyre-ignore[16]
        self.frame_annots = sorted(
            self.frame_annots,
            key=lambda f: (
                f["frame_annotation"].sequence_name,
                f["frame_annotation"].frame_timestamp or 0,
            ),
        )

    def _filter_db(self) -> None:
        if self.remove_empty_masks:
            logger.info("Removing images with empty masks.")
            # pyre-ignore[16]
            old_len = len(self.frame_annots)

            msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."

            def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
                mask = frame_annot.mask
                if mask is None:
                    return False
                if mask.mass is None:
                    raise ValueError(msg)
                return mask.mass > 1

            self.frame_annots = [
                frame
                for frame in self.frame_annots
                if positive_mass(frame["frame_annotation"])
            ]
            logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))

        # this has to be called after joining with categories!!
        subsets = self.subsets
        if subsets:
            if not self.subset_lists_file:
                raise ValueError(
                    "Subset filter is on but subset_lists_file was not given"
                )

            logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")

            # truncate the list of subsets to the valid one
            self.frame_annots = [
                entry for entry in self.frame_annots if entry["subset"] in subsets
            ]
            if len(self.frame_annots) == 0:
                raise ValueError(f"There are no frames in the '{subsets}' subsets!")

            self._invalidate_indexes(filter_seq_annots=True)

        if len(self.limit_category_to) > 0:
            logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
            # pyre-ignore[16]
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                if entry.category in self.limit_category_to
            }

        # sequence filters
        for prefix in ("pick", "exclude"):
            orig_len = len(self.seq_annots)
            attr = f"{prefix}_sequence"
            arr = getattr(self, attr)
            if len(arr) > 0:
                logger.info(f"{attr}: {str(arr)}")
                self.seq_annots = {
                    name: entry
                    for name, entry in self.seq_annots.items()
                    if (name in arr) == (prefix == "pick")
                }
                logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))

        if self.limit_sequences_to > 0:
            self.seq_annots = dict(
                islice(self.seq_annots.items(), self.limit_sequences_to)
            )

        # retain only frames from retained sequences
        self.frame_annots = [
            f
            for f in self.frame_annots
            if f["frame_annotation"].sequence_name in self.seq_annots
        ]

        self._invalidate_indexes()

        if self.n_frames_per_sequence > 0:
            logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
            keep_idx = []
            # pyre-ignore[16]
            for seq, seq_indices in self._seq_to_idx.items():
                # infer the seed from the sequence name, this is reproducible
                # and makes the selection differ for different sequences
                seed = _seq_name_to_seed(seq) + self.seed
                seq_idx_shuffled = random.Random(seed).sample(
                    sorted(seq_indices), len(seq_indices)
                )
                keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])

            logger.info(
                "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
            )
            self.frame_annots = [self.frame_annots[i] for i in keep_idx]
            self._invalidate_indexes(filter_seq_annots=False)
            # sequences are not decimated, so self.seq_annots is valid

        if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
            logger.info(
                "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
            )
            self.frame_annots = self.frame_annots[: self.limit_to]
            self._invalidate_indexes(filter_seq_annots=True)

    def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
        # update _seq_to_idx and filter seq_meta according to frame_annots change
        # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
        self._invalidate_seq_to_idx()

        if filter_seq_annots:
            # pyre-ignore[16]
            self.seq_annots = {
                k: v
                for k, v in self.seq_annots.items()
                # pyre-ignore[16]
                if k in self._seq_to_idx
            }

    def _invalidate_seq_to_idx(self) -> None:
        seq_to_idx = defaultdict(list)
        # pyre-ignore[16]
        for idx, entry in enumerate(self.frame_annots):
            seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
        # pyre-ignore[16]
        self._seq_to_idx = seq_to_idx

    def _resize_image(
        self, image, mode="bilinear"
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        image_height, image_width = self.image_height, self.image_width
        if image_height is None or image_width is None:
            # skip the resizing
            imre_ = torch.from_numpy(image)
            return imre_, 1.0, torch.ones_like(imre_[:1])
        # takes numpy array, returns pytorch tensor
        minscale = min(
            image_height / image.shape[-2],
            image_width / image.shape[-1],
        )
        imre = torch.nn.functional.interpolate(
            torch.from_numpy(image)[None],
            scale_factor=minscale,
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
            recompute_scale_factor=True,
        )[0]
        # pyre-fixme[19]: Expected 1 positional argument.
        imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
        imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
        # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
        # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
        mask = torch.zeros(1, self.image_height, self.image_width)
        mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
        return imre_, minscale, mask

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
        return self.path_manager.get_local_path(path)

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        out: List[Tuple[int, float]] = []
        for idx in idxs:
            # pyre-ignore[16]
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
            out.append(
                (frame_annotation.frame_number, frame_annotation.frame_timestamp)
            )
        return out

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        return self.eval_batches


def _seq_name_to_seed(seq_name) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def _load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def _load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = _load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def _load_mask(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def _get_1d_bounds(arr) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1


def _get_bbox_from_mask(
    mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
        if thr <= 0.0:
            warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0
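

# Illustrative example for _get_bbox_from_mask (hypothetical values): for a
# (1, 10, 10) mask whose foreground is mask[0, 2:6, 3:9] = 1 and thr = 0.4,
# summing over the height axis gives non-zero columns 3..8 and summing over
# the width axis gives non-zero rows 2..5, so the function returns
# (x0, y0, w, h) = (3, 2, 6, 4).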


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def _crop_around_box(tensor, bbox, impath: str = ""):
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = _clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tensor.shape[-2:],
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy
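

# Small illustrative round trip for the two bbox helpers above (hypothetical
# values): _bbox_xywh_to_xyxy(torch.tensor([3, 2, 6, 4])) -> tensor([3, 2, 9, 6])
# and _bbox_xyxy_to_xywh(torch.tensor([3, 2, 9, 6])) -> tensor([3, 2, 6, 4]).
# With clamp_size=2, a degenerate box such as [3, 2, 1, 0] first has its
# width/height raised to at least 2 before the conversion.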


def _safe_as_tensor(data, dtype):
    if data is None:
        return None
    return torch.tensor(data, dtype=dtype)


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl