mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
SQL Index Dataset
Summary: Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay. It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS. Reviewed By: bottler Differential Revision: D45086611 fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
This commit is contained in:
parent
7aeedd17a4
commit
32e1992924
@ -450,6 +450,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
|||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: types.FrameAnnotation,
|
||||||
sequence_annotation: types.SequenceAnnotation,
|
sequence_annotation: types.SequenceAnnotation,
|
||||||
|
load_blobs: bool = True,
|
||||||
) -> FrameDataSubtype:
|
) -> FrameDataSubtype:
|
||||||
"""An abstract method to build the frame data based on raw frame/sequence
|
"""An abstract method to build the frame data based on raw frame/sequence
|
||||||
annotations, load the binary data and adjust them according to the metadata.
|
annotations, load the binary data and adjust them according to the metadata.
|
||||||
@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
Beware that modifications of frame data are done in-place.
|
Beware that modifications of frame data are done in-place.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_root: The root folder of the dataset; all the paths in jsons are
|
dataset_root: The root folder of the dataset; all paths in frame / sequence
|
||||||
specified relative to this root (but not json paths themselves).
|
annotations are defined w.r.t. this root. Has to be set if any of the
|
||||||
|
load_* flabs below is true.
|
||||||
load_images: Enable loading the frame RGB data.
|
load_images: Enable loading the frame RGB data.
|
||||||
load_depths: Enable loading the frame depth maps.
|
load_depths: Enable loading the frame depth maps.
|
||||||
load_depth_masks: Enable loading the frame depth map masks denoting the
|
load_depth_masks: Enable loading the frame depth map masks denoting the
|
||||||
@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
path_manager: Optionally a PathManager for interpreting paths in a special way.
|
path_manager: Optionally a PathManager for interpreting paths in a special way.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset_root: str = ""
|
dataset_root: Optional[str] = None
|
||||||
load_images: bool = True
|
load_images: bool = True
|
||||||
load_depths: bool = True
|
load_depths: bool = True
|
||||||
load_depth_masks: bool = True
|
load_depth_masks: bool = True
|
||||||
@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
box_crop_context: float = 0.3
|
box_crop_context: float = 0.3
|
||||||
path_manager: Any = None
|
path_manager: Any = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
load_any_blob = (
|
||||||
|
self.load_images
|
||||||
|
or self.load_depths
|
||||||
|
or self.load_depth_masks
|
||||||
|
or self.load_masks
|
||||||
|
or self.load_point_clouds
|
||||||
|
)
|
||||||
|
if load_any_blob and self.dataset_root is None:
|
||||||
|
raise ValueError(
|
||||||
|
"dataset_root must be set to load any blob data. "
|
||||||
|
"Make sure it is set in either FrameDataBuilder or Dataset params."
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore
|
||||||
|
raise ValueError(
|
||||||
|
f"dataset_root is passed but {self.dataset_root} does not exist."
|
||||||
|
)
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: types.FrameAnnotation,
|
||||||
@ -567,7 +588,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
if bbox_xywh is None and fg_mask_np is not None:
|
if bbox_xywh is None and fg_mask_np is not None:
|
||||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||||
|
|
||||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
|
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||||
|
|
||||||
if frame_annotation.image is not None:
|
if frame_annotation.image is not None:
|
||||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||||
@ -612,7 +633,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
def _load_fg_probability(
|
def _load_fg_probability(
|
||||||
self, entry: types.FrameAnnotation
|
self, entry: types.FrameAnnotation
|
||||||
) -> Tuple[np.ndarray, str]:
|
) -> Tuple[np.ndarray, str]:
|
||||||
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
|
assert self.dataset_root is not None and entry.mask is not None
|
||||||
|
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||||
fg_probability = load_mask(self._local_path(full_path))
|
fg_probability = load_mask(self._local_path(full_path))
|
||||||
if fg_probability.shape[-2:] != entry.image.size:
|
if fg_probability.shape[-2:] != entry.image.size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -647,7 +669,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
fg_probability: Optional[torch.Tensor],
|
fg_probability: Optional[torch.Tensor],
|
||||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||||
entry_depth = entry.depth
|
entry_depth = entry.depth
|
||||||
assert entry_depth is not None
|
assert self.dataset_root is not None and entry_depth is not None
|
||||||
path = os.path.join(self.dataset_root, entry_depth.path)
|
path = os.path.join(self.dataset_root, entry_depth.path)
|
||||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||||
|
|
||||||
@ -657,6 +679,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
if self.load_depth_masks:
|
if self.load_depth_masks:
|
||||||
assert entry_depth.mask_path is not None
|
assert entry_depth.mask_path is not None
|
||||||
|
# pyre-ignore
|
||||||
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
||||||
depth_mask = load_depth_mask(self._local_path(mask_path))
|
depth_mask = load_depth_mask(self._local_path(mask_path))
|
||||||
else:
|
else:
|
||||||
@ -705,6 +728,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
)
|
)
|
||||||
if path.startswith(unwanted_prefix):
|
if path.startswith(unwanted_prefix):
|
||||||
path = path[len(unwanted_prefix) :]
|
path = path[len(unwanted_prefix) :]
|
||||||
|
assert self.dataset_root is not None
|
||||||
return os.path.join(self.dataset_root, path)
|
return os.path.join(self.dataset_root, path)
|
||||||
|
|
||||||
def _local_path(self, path: str) -> str:
|
def _local_path(self, path: str) -> str:
|
||||||
|
161
pytorch3d/implicitron/dataset/orm_types.py
Normal file
161
pytorch3d/implicitron/dataset/orm_types.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# This functionality requires SQLAlchemy 2.0 or later.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.types import (
|
||||||
|
DepthAnnotation,
|
||||||
|
ImageAnnotation,
|
||||||
|
MaskAnnotation,
|
||||||
|
PointCloudAnnotation,
|
||||||
|
VideoAnnotation,
|
||||||
|
ViewpointAnnotation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sqlalchemy import LargeBinary
|
||||||
|
from sqlalchemy.orm import (
|
||||||
|
composite,
|
||||||
|
DeclarativeBase,
|
||||||
|
Mapped,
|
||||||
|
mapped_column,
|
||||||
|
MappedAsDataclass,
|
||||||
|
)
|
||||||
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
|
||||||
|
|
||||||
|
# these produce policies to serialize structured types to blobs
|
||||||
|
def ArrayTypeFactory(shape):
|
||||||
|
class NumpyArrayType(TypeDecorator):
|
||||||
|
impl = LargeBinary
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is not None:
|
||||||
|
if value.shape != shape:
|
||||||
|
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
|
||||||
|
return value.astype(np.float32).tobytes()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if value is not None:
|
||||||
|
return np.frombuffer(value, dtype=np.float32).reshape(shape)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return NumpyArrayType
|
||||||
|
|
||||||
|
|
||||||
|
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
|
||||||
|
format_symbol = {
|
||||||
|
float: "f", # float32
|
||||||
|
int: "i", # int32
|
||||||
|
}[dtype]
|
||||||
|
|
||||||
|
class TupleType(TypeDecorator):
|
||||||
|
impl = LargeBinary
|
||||||
|
_format = format_symbol * math.prod(shape)
|
||||||
|
|
||||||
|
def process_bind_param(self, value, _):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(shape) > 1:
|
||||||
|
value = np.array(value, dtype=dtype).reshape(-1)
|
||||||
|
|
||||||
|
return struct.pack(TupleType._format, *value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, _):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
loaded = struct.unpack(TupleType._format, value)
|
||||||
|
if len(shape) > 1:
|
||||||
|
loaded = _rec_totuple(
|
||||||
|
np.array(loaded, dtype=dtype).reshape(shape).tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
|
||||||
|
return TupleType
|
||||||
|
|
||||||
|
|
||||||
|
def _rec_totuple(t):
|
||||||
|
if isinstance(t, list):
|
||||||
|
return tuple(_rec_totuple(x) for x in t)
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
class Base(MappedAsDataclass, DeclarativeBase):
|
||||||
|
"""subclasses will be converted to dataclasses"""
|
||||||
|
|
||||||
|
|
||||||
|
class SqlFrameAnnotation(Base):
|
||||||
|
__tablename__ = "frame_annots"
|
||||||
|
|
||||||
|
sequence_name: Mapped[str] = mapped_column(primary_key=True)
|
||||||
|
frame_number: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
frame_timestamp: Mapped[float] = mapped_column(index=True)
|
||||||
|
|
||||||
|
image: Mapped[ImageAnnotation] = composite(
|
||||||
|
mapped_column("_image_path"),
|
||||||
|
mapped_column("_image_size", TupleTypeFactory(int)),
|
||||||
|
)
|
||||||
|
|
||||||
|
depth: Mapped[DepthAnnotation] = composite(
|
||||||
|
mapped_column("_depth_path", nullable=True),
|
||||||
|
mapped_column("_depth_scale_adjustment", nullable=True),
|
||||||
|
mapped_column("_depth_mask_path", nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
mask: Mapped[MaskAnnotation] = composite(
|
||||||
|
mapped_column("_mask_path", nullable=True),
|
||||||
|
mapped_column("_mask_mass", index=True, nullable=True),
|
||||||
|
mapped_column(
|
||||||
|
"_mask_bounding_box_xywh",
|
||||||
|
TupleTypeFactory(float, shape=(4,)),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
viewpoint: Mapped[ViewpointAnnotation] = composite(
|
||||||
|
mapped_column(
|
||||||
|
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
|
||||||
|
),
|
||||||
|
mapped_column(
|
||||||
|
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
|
||||||
|
),
|
||||||
|
mapped_column(
|
||||||
|
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
|
||||||
|
),
|
||||||
|
mapped_column(
|
||||||
|
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
|
||||||
|
),
|
||||||
|
mapped_column("_viewpoint_intrinsics_format", nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SqlSequenceAnnotation(Base):
|
||||||
|
__tablename__ = "sequence_annots"
|
||||||
|
|
||||||
|
sequence_name: Mapped[str] = mapped_column(primary_key=True)
|
||||||
|
category: Mapped[str] = mapped_column(index=True)
|
||||||
|
|
||||||
|
video: Mapped[VideoAnnotation] = composite(
|
||||||
|
mapped_column("_video_path", nullable=True),
|
||||||
|
mapped_column("_video_length", nullable=True),
|
||||||
|
)
|
||||||
|
point_cloud: Mapped[PointCloudAnnotation] = composite(
|
||||||
|
mapped_column("_point_cloud_path", nullable=True),
|
||||||
|
mapped_column("_point_cloud_quality_score", nullable=True),
|
||||||
|
mapped_column("_point_cloud_n_points", nullable=True),
|
||||||
|
)
|
||||||
|
# the bigger the better
|
||||||
|
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
|
735
pytorch3d/implicitron/dataset/sql_dataset.py
Normal file
735
pytorch3d/implicitron/dataset/sql_dataset.py
Normal file
@ -0,0 +1,735 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
||||||
|
FrameData,
|
||||||
|
FrameDataBuilder,
|
||||||
|
FrameDataBuilderBase,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
registry,
|
||||||
|
ReplaceableBase,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_SET_LISTS_TABLE: str = "set_lists"
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||||
|
"""
|
||||||
|
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||||
|
The length is returned after all sequence and frame filters are applied (see param
|
||||||
|
definitions below). Indices can either be ordinal in [0, len), or pairs of
|
||||||
|
(sequence_name, frame_number); with the performance of `dataset[i]` and
|
||||||
|
`dataset[sequence_name, frame_number]` being same. A faster way to get metadata only
|
||||||
|
(without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False.
|
||||||
|
With ordinal indexing, the sequences are NOT guaranteed to span contiguous index
|
||||||
|
ranges, and frame numbers are NOT guaranteed to be increasing within a sequence.
|
||||||
|
Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order`
|
||||||
|
iterators, which are efficient.
|
||||||
|
|
||||||
|
This functionality requires SQLAlchemy 2.0 or later.
|
||||||
|
|
||||||
|
Metadata-related args:
|
||||||
|
sqlite_metadata_file: A SQLite file containing frame and sequence annotation
|
||||||
|
tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation,
|
||||||
|
respectively).
|
||||||
|
dataset_root: A root directory to look for images, masks, etc. It can be
|
||||||
|
alternatively set in `frame_data_builder` args, but this takes precedence.
|
||||||
|
subset_lists_file: A JSON/sqlite file containing the lists of frames
|
||||||
|
corresponding to different subsets (e.g. train/val/test) of the dataset;
|
||||||
|
format: {subset: [(sequence_name, frame_id, file_path)]}. All entries
|
||||||
|
must be present in frame_annotation metadata table.
|
||||||
|
path_manager: a facade for non-POSIX filesystems.
|
||||||
|
subsets: Restrict frames/sequences only to the given list of subsets
|
||||||
|
as defined in subset_lists_file (see above). Applied before all other
|
||||||
|
filters.
|
||||||
|
remove_empty_masks: Removes the frames with no active foreground pixels
|
||||||
|
in the segmentation mask (needs frame_annotation.mask.mass to be set;
|
||||||
|
null values are retained).
|
||||||
|
pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations
|
||||||
|
NOTE: This is a potential security risk! The string is passed to the SQL
|
||||||
|
engine verbatim. Don’t expose it to end users of your application!
|
||||||
|
pick_categories: Restrict the dataset to the given list of categories.
|
||||||
|
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||||
|
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||||
|
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
|
||||||
|
sequences (after other sequence filters have been applied but before
|
||||||
|
frame-based filters).
|
||||||
|
limit_to: Limit the dataset to the first #limit_to frames (after other
|
||||||
|
filters have been applied, except n_frames_per_sequence).
|
||||||
|
n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence`
|
||||||
|
frames in each sequences uniformly without replacement if it has
|
||||||
|
more frames than that; applied after other frame-level filters.
|
||||||
|
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||||
|
random frames per sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||||
|
|
||||||
|
sqlite_metadata_file: str = ""
|
||||||
|
dataset_root: Optional[str] = None
|
||||||
|
subset_lists_file: str = ""
|
||||||
|
eval_batches_file: Optional[str] = None
|
||||||
|
path_manager: Any = None
|
||||||
|
subsets: Optional[List[str]] = None
|
||||||
|
remove_empty_masks: bool = True
|
||||||
|
pick_frames_sql_clause: Optional[str] = None
|
||||||
|
pick_categories: Tuple[str, ...] = ()
|
||||||
|
|
||||||
|
pick_sequences: Tuple[str, ...] = ()
|
||||||
|
exclude_sequences: Tuple[str, ...] = ()
|
||||||
|
limit_sequences_to: int = 0
|
||||||
|
limit_to: int = 0
|
||||||
|
n_frames_per_sequence: int = -1
|
||||||
|
seed: int = 0
|
||||||
|
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||||
|
# we set it manually in the constructor
|
||||||
|
# _index: pd.DataFrame = field(init=False)
|
||||||
|
|
||||||
|
frame_data_builder: FrameDataBuilderBase
|
||||||
|
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if sa.__version__ < "2.0":
|
||||||
|
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||||
|
|
||||||
|
if not self.sqlite_metadata_file:
|
||||||
|
raise ValueError("sqlite_metadata_file must be set")
|
||||||
|
|
||||||
|
if self.dataset_root:
|
||||||
|
frame_builder_type = self.frame_data_builder_class_type
|
||||||
|
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
|
||||||
|
"dataset_root"
|
||||||
|
] = self.dataset_root
|
||||||
|
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
# pyre-ignore
|
||||||
|
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
|
||||||
|
|
||||||
|
sequences = self._get_filtered_sequences_if_any()
|
||||||
|
|
||||||
|
if self.subsets:
|
||||||
|
index = self._build_index_from_subset_lists(sequences)
|
||||||
|
else:
|
||||||
|
# TODO: if self.subset_lists_file and not self.subsets, it might be faster to
|
||||||
|
# still use the concatenated lists, assuming they cover the whole dataset
|
||||||
|
index = self._build_index_from_db(sequences)
|
||||||
|
|
||||||
|
if self.n_frames_per_sequence >= 0:
|
||||||
|
index = self._stratified_sample_index(index)
|
||||||
|
|
||||||
|
if len(index) == 0:
|
||||||
|
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
||||||
|
|
||||||
|
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
|
||||||
|
|
||||||
|
self.eval_batches = None # pyre-ignore
|
||||||
|
if self.eval_batches_file:
|
||||||
|
self.eval_batches = self._load_filter_eval_batches()
|
||||||
|
|
||||||
|
logger.info(str(self))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return len(self._index)
|
||||||
|
|
||||||
|
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||||
|
"""
|
||||||
|
Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair
|
||||||
|
"""
|
||||||
|
return self._get_item(frame_idx, True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta(self):
|
||||||
|
"""
|
||||||
|
Allows accessing metadata only without loading blobs using `dataset.meta[idx]`.
|
||||||
|
Requires box_crop==False, since in that case, cameras cannot be adjusted
|
||||||
|
without loading masks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FrameData objects with blob fields like `image_rgb` set to None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if dataset.box_crop is set.
|
||||||
|
"""
|
||||||
|
return SqlIndexDataset._MetadataAccessor(self)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _MetadataAccessor:
|
||||||
|
dataset: "SqlIndexDataset"
|
||||||
|
|
||||||
|
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||||
|
return self.dataset._get_item(frame_idx, False)
|
||||||
|
|
||||||
|
def _get_item(
|
||||||
|
self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True
|
||||||
|
) -> FrameData:
|
||||||
|
if isinstance(frame_idx, int):
|
||||||
|
if frame_idx >= len(self._index):
|
||||||
|
raise IndexError(f"index {frame_idx} out of range {len(self._index)}")
|
||||||
|
|
||||||
|
seq, frame = self._index.index[frame_idx]
|
||||||
|
else:
|
||||||
|
seq, frame, *rest = frame_idx
|
||||||
|
if (seq, frame) not in self._index.index:
|
||||||
|
raise IndexError(
|
||||||
|
f"Sequence-frame index {frame_idx} not found; was it filtered out?"
|
||||||
|
)
|
||||||
|
|
||||||
|
if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]:
|
||||||
|
raise IndexError(f"Non-matching image path in {frame_idx}.")
|
||||||
|
|
||||||
|
stmt = sa.select(self.frame_annotations_type).where(
|
||||||
|
self.frame_annotations_type.sequence_name == seq,
|
||||||
|
self.frame_annotations_type.frame_number
|
||||||
|
== int(frame), # cast from np.int64
|
||||||
|
)
|
||||||
|
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
||||||
|
SqlSequenceAnnotation.sequence_name == seq
|
||||||
|
)
|
||||||
|
with Session(self._sql_engine) as session:
|
||||||
|
entry = session.scalars(stmt).one()
|
||||||
|
seq_metadata = session.scalars(seq_stmt).one()
|
||||||
|
|
||||||
|
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||||
|
|
||||||
|
frame_data = self.frame_data_builder.build(
|
||||||
|
entry, seq_metadata, load_blobs=load_blobs
|
||||||
|
)
|
||||||
|
|
||||||
|
# The rest of the fields are optional
|
||||||
|
frame_data.frame_type = self._get_frame_type(entry)
|
||||||
|
return frame_data
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return f"SqlIndexDataset #frames={len(self._index)}"
|
||||||
|
|
||||||
|
def sequence_names(self) -> Iterable[str]:
|
||||||
|
"""Returns an iterator over sequence names in the dataset."""
|
||||||
|
return self._index.index.unique("sequence_name")
|
||||||
|
|
||||||
|
# override
|
||||||
|
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||||
|
stmt = sa.select(
|
||||||
|
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
||||||
|
).where( # we limit results to sequences that have frames after all filters
|
||||||
|
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
||||||
|
)
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||||
|
|
||||||
|
return cat_to_seqs.groupby("category")["sequence_name"].apply(list).to_dict()
|
||||||
|
|
||||||
|
# override
|
||||||
|
def get_frame_numbers_and_timestamps(
|
||||||
|
self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
|
||||||
|
) -> List[Tuple[int, float]]:
|
||||||
|
"""
|
||||||
|
Implements the DatasetBase method.
|
||||||
|
|
||||||
|
NOTE: Avoid this function as there are more efficient alternatives such as
|
||||||
|
querying `dataset[idx]` directly or getting all sequence frames with
|
||||||
|
`sequence_[frames|indices]_in_order`.
|
||||||
|
|
||||||
|
Return the index and timestamp in their videos of the frames whose
|
||||||
|
indices are given in `idxs`. They need to belong to the same sequence!
|
||||||
|
If timestamps are absent, they are replaced with zeros.
|
||||||
|
This is used for letting SceneBatchSampler identify consecutive
|
||||||
|
frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idxs: a sequence int frame index in the dataset (it can be a slice)
|
||||||
|
subset_filter: must remain None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of tuples of
|
||||||
|
- frame index in video
|
||||||
|
- timestamp of frame in video, coalesced with 0s
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if idxs belong to more than one sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if subset_filter is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Subset filters are not supported in SQL Dataset. "
|
||||||
|
"We encourage creating a dataset per subset."
|
||||||
|
)
|
||||||
|
|
||||||
|
index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs)
|
||||||
|
# alternatively, we can use `.values.tolist()`, which may be faster
|
||||||
|
# but returns a list of lists
|
||||||
|
return list(index_slice.itertuples())
|
||||||
|
|
||||||
|
# override
|
||||||
|
def sequence_frames_in_order(
|
||||||
|
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
|
||||||
|
) -> Iterator[Tuple[float, int, int]]:
|
||||||
|
"""
|
||||||
|
Overrides the default DatasetBase implementation (we don’t use `_seq_to_idx`).
|
||||||
|
Returns an iterator over the frame indices in a given sequence.
|
||||||
|
We attempt to first sort by timestamp (if they are available),
|
||||||
|
then by frame number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_name: the name of the sequence.
|
||||||
|
subset_filter: subset names to filter to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||||
|
where `frame_no` is the index within the sequence, and
|
||||||
|
`dataset_idx` is the index within the dataset.
|
||||||
|
`None` timestamps are replaced with 0s.
|
||||||
|
"""
|
||||||
|
# TODO: implement sort_timestamp_first? (which would matter if the orders
|
||||||
|
# of frame numbers and timestamps are different)
|
||||||
|
rows = self._index.index.get_loc(seq_name)
|
||||||
|
if isinstance(rows, slice):
|
||||||
|
assert rows.stop is not None, "Unexpected result from pandas"
|
||||||
|
rows = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||||
|
else:
|
||||||
|
rows = np.where(rows)[0]
|
||||||
|
|
||||||
|
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
||||||
|
rows, seq_name, subset_filter
|
||||||
|
)
|
||||||
|
index_slice["idx"] = idx
|
||||||
|
|
||||||
|
yield from index_slice.itertuples(index=False)
|
||||||
|
|
||||||
|
# override
|
||||||
|
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||||
|
"""
|
||||||
|
This class does not support eval batches with ordinal indices. You can pass
|
||||||
|
eval_batches as a batch_sampler to a data_loader since the dataset supports
|
||||||
|
`dataset[seq_name, frame_no]` indexing.
|
||||||
|
"""
|
||||||
|
return self.eval_batches
|
||||||
|
|
||||||
|
# override
|
||||||
|
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
|
||||||
|
raise ValueError("Not supported! Preprocess the data by merging them instead.")
|
||||||
|
|
||||||
|
# override
|
||||||
|
@property
|
||||||
|
def frame_data_type(self) -> Type[FrameData]:
|
||||||
|
return self.frame_data_builder.frame_data_type
|
||||||
|
|
||||||
|
def is_filtered(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns `True` in case the dataset has been filtered and thus some frame
|
||||||
|
annotations stored on the disk might be missing in the dataset object.
|
||||||
|
Does not account for subsets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
is_filtered: `True` if the dataset has been filtered, else `False`.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.remove_empty_masks
|
||||||
|
or self.limit_to > 0
|
||||||
|
or self.limit_sequences_to > 0
|
||||||
|
or len(self.pick_sequences) > 0
|
||||||
|
or len(self.exclude_sequences) > 0
|
||||||
|
or len(self.pick_categories) > 0
|
||||||
|
or self.n_frames_per_sequence > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||||
|
# maximum possible query: WHERE category IN 'self.pick_categories'
|
||||||
|
# AND sequence_name IN 'self.pick_sequences'
|
||||||
|
# AND sequence_name NOT IN 'self.exclude_sequences'
|
||||||
|
# LIMIT 'self.limit_sequence_to'
|
||||||
|
|
||||||
|
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
|
||||||
|
|
||||||
|
where_conditions = [
|
||||||
|
*self._get_category_filters(),
|
||||||
|
*self._get_pick_filters(),
|
||||||
|
*self._get_exclude_filters(),
|
||||||
|
]
|
||||||
|
if where_conditions:
|
||||||
|
stmt = stmt.where(*where_conditions)
|
||||||
|
|
||||||
|
if self.limit_sequences_to > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Limiting dataset to first {self.limit_sequences_to} sequences"
|
||||||
|
)
|
||||||
|
# NOTE: ROWID is SQLite-specific
|
||||||
|
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
|
||||||
|
|
||||||
|
if not where_conditions and self.limit_sequences_to <= 0:
|
||||||
|
# we will not need to filter by sequences
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
sequences = pd.read_sql_query(stmt, connection)["sequence_name"]
|
||||||
|
logger.info("... retained %d sequences" % len(sequences))
|
||||||
|
|
||||||
|
return sequences
|
||||||
|
|
||||||
|
def _get_category_filters(self) -> List[sa.ColumnElement]:
|
||||||
|
if not self.pick_categories:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||||
|
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
||||||
|
|
||||||
|
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||||
|
if not self.pick_sequences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||||
|
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
||||||
|
|
||||||
|
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||||
|
if not self.exclude_sequences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||||
|
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
||||||
|
|
||||||
|
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||||
|
assert self.subsets is not None
|
||||||
|
with open(subset_lists_path, "r") as f:
|
||||||
|
subset_to_seq_frame = json.load(f)
|
||||||
|
|
||||||
|
seq_frame_list = sum(
|
||||||
|
(
|
||||||
|
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||||
|
for subset in self.subsets
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
index = pd.DataFrame(
|
||||||
|
seq_frame_list,
|
||||||
|
columns=["sequence_name", "frame_number", "_image_path", "subset"],
|
||||||
|
)
|
||||||
|
return index
|
||||||
|
|
||||||
|
def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame:
|
||||||
|
subsets = self.subsets
|
||||||
|
assert subsets is not None
|
||||||
|
# we need a new engine since we store the subsets in a separate DB
|
||||||
|
engine = sa.create_engine(f"sqlite:///{subset_lists_path}")
|
||||||
|
table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine)
|
||||||
|
stmt = sa.select(table).where(table.c.subset.in_(subsets))
|
||||||
|
with engine.connect() as connection:
|
||||||
|
index = pd.read_sql(stmt, connection)
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
def _build_index_from_subset_lists(
|
||||||
|
self, sequences: Optional[pd.Series]
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
if not self.subset_lists_file:
|
||||||
|
raise ValueError("Requested subsets but subset_lists_file not given")
|
||||||
|
|
||||||
|
logger.info(f"Loading subset lists from {self.subset_lists_file}.")
|
||||||
|
|
||||||
|
subset_lists_path = self._local_path(self.subset_lists_file)
|
||||||
|
if subset_lists_path.lower().endswith(".json"):
|
||||||
|
index = self._load_subsets_from_json(subset_lists_path)
|
||||||
|
else:
|
||||||
|
index = self._load_subsets_from_sql(subset_lists_path)
|
||||||
|
index = index.set_index(["sequence_name", "frame_number"])
|
||||||
|
logger.info(f" -> loaded {len(index)} samples of {self.subsets}.")
|
||||||
|
|
||||||
|
if sequences is not None:
|
||||||
|
logger.info("Applying filtered sequences.")
|
||||||
|
sequence_values = index.index.get_level_values("sequence_name")
|
||||||
|
index = index.loc[sequence_values.isin(sequences)]
|
||||||
|
logger.info(f" -> retained {len(index)} samples.")
|
||||||
|
|
||||||
|
pick_frames_criteria = []
|
||||||
|
if self.remove_empty_masks:
|
||||||
|
logger.info("Culling samples with empty masks.")
|
||||||
|
|
||||||
|
if len(index) > self.remove_empty_masks_poll_whole_table_threshold:
|
||||||
|
# APPROACH 1: find empty masks and drop indices.
|
||||||
|
# dev load: 17s / 15 s (3.1M / 500K)
|
||||||
|
stmt = sa.select(
|
||||||
|
self.frame_annotations_type.sequence_name,
|
||||||
|
self.frame_annotations_type.frame_number,
|
||||||
|
).where(self.frame_annotations_type._mask_mass == 0)
|
||||||
|
with Session(self._sql_engine) as session:
|
||||||
|
to_remove = session.execute(stmt).all()
|
||||||
|
|
||||||
|
# Pandas uses np.int64 for integer types, so we have to case
|
||||||
|
# we might want to read it to pandas DataFrame directly to avoid the loop
|
||||||
|
to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove]
|
||||||
|
index.drop(to_remove, errors="ignore", inplace=True)
|
||||||
|
else:
|
||||||
|
# APPROACH 3: load index into a temp table and join with annotations
|
||||||
|
# dev load: 94 s / 23 s (3.1M / 500K)
|
||||||
|
pick_frames_criteria.append(
|
||||||
|
sa.or_(
|
||||||
|
self.frame_annotations_type._mask_mass.is_(None),
|
||||||
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.pick_frames_sql_clause:
|
||||||
|
logger.info("Applying the custom SQL clause.")
|
||||||
|
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
|
||||||
|
|
||||||
|
if pick_frames_criteria:
|
||||||
|
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
||||||
|
|
||||||
|
logger.info(f" -> retained {len(index)} samples.")
|
||||||
|
|
||||||
|
if self.limit_to > 0:
|
||||||
|
logger.info(f"Limiting dataset to first {self.limit_to} frames")
|
||||||
|
index = index.sort_index().iloc[: self.limit_to]
|
||||||
|
|
||||||
|
return index.reset_index()
|
||||||
|
|
||||||
|
def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame:
|
||||||
|
IndexTable = self._get_temp_index_table_instance()
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
IndexTable.create(connection)
|
||||||
|
# we don’t let pandas’s `to_sql` create the table automatically as
|
||||||
|
# the table would be permanent, so we create it and append with pandas
|
||||||
|
n_rows = index.to_sql(IndexTable.name, connection, if_exists="append")
|
||||||
|
assert n_rows == len(index)
|
||||||
|
sa_type = self.frame_annotations_type
|
||||||
|
stmt = (
|
||||||
|
sa.select(IndexTable)
|
||||||
|
.select_from(
|
||||||
|
IndexTable.join(
|
||||||
|
self.frame_annotations_type,
|
||||||
|
sa.and_(
|
||||||
|
sa_type.sequence_name == IndexTable.c.sequence_name,
|
||||||
|
sa_type.frame_number == IndexTable.c.frame_number,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.where(*criteria)
|
||||||
|
)
|
||||||
|
return pd.read_sql_query(stmt, connection).set_index(
|
||||||
|
["sequence_name", "frame_number"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_index_from_db(self, sequences: Optional[pd.Series]):
|
||||||
|
logger.info("Loading sequcence-frame index from the database")
|
||||||
|
stmt = sa.select(
|
||||||
|
self.frame_annotations_type.sequence_name,
|
||||||
|
self.frame_annotations_type.frame_number,
|
||||||
|
self.frame_annotations_type._image_path,
|
||||||
|
sa.null().label("subset"),
|
||||||
|
)
|
||||||
|
where_conditions = []
|
||||||
|
if sequences is not None:
|
||||||
|
logger.info(" applying filtered sequences")
|
||||||
|
where_conditions.append(
|
||||||
|
self.frame_annotations_type.sequence_name.in_(sequences.tolist())
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.remove_empty_masks:
|
||||||
|
logger.info(" excluding samples with empty masks")
|
||||||
|
where_conditions.append(
|
||||||
|
sa.or_(
|
||||||
|
self.frame_annotations_type._mask_mass.is_(None),
|
||||||
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.pick_frames_sql_clause:
|
||||||
|
logger.info(" applying custom SQL clause")
|
||||||
|
where_conditions.append(sa.text(self.pick_frames_sql_clause))
|
||||||
|
|
||||||
|
if where_conditions:
|
||||||
|
stmt = stmt.where(*where_conditions)
|
||||||
|
|
||||||
|
if self.limit_to > 0:
|
||||||
|
logger.info(f"Limiting dataset to first {self.limit_to} frames")
|
||||||
|
stmt = stmt.order_by(
|
||||||
|
self.frame_annotations_type.sequence_name,
|
||||||
|
self.frame_annotations_type.frame_number,
|
||||||
|
).limit(self.limit_to)
|
||||||
|
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
index = pd.read_sql_query(stmt, connection)
|
||||||
|
|
||||||
|
logger.info(f" -> loaded {len(index)} samples.")
|
||||||
|
return index
|
||||||
|
|
||||||
|
def _sort_index_(self, index):
|
||||||
|
logger.info("Sorting the index by sequence and frame number.")
|
||||||
|
index.sort_values(["sequence_name", "frame_number"], inplace=True)
|
||||||
|
logger.info(" -> Done.")
|
||||||
|
|
||||||
|
def _load_filter_eval_batches(self):
|
||||||
|
assert self.eval_batches_file
|
||||||
|
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
||||||
|
|
||||||
|
if not os.path.isfile(self.eval_batches_file):
|
||||||
|
# The batch indices file does not exist.
|
||||||
|
# Most probably the user has not specified the root folder.
|
||||||
|
raise ValueError(
|
||||||
|
f"Looking for dataset json file in {self.eval_batches_file}. "
|
||||||
|
+ "Please specify a correct dataset_root folder."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(self.eval_batches_file, "r") as f:
|
||||||
|
eval_batches = json.load(f)
|
||||||
|
|
||||||
|
# limit the dataset to sequences to allow multiple evaluations in one file
|
||||||
|
pick_sequences = set(self.pick_sequences)
|
||||||
|
if self.pick_categories:
|
||||||
|
cat_to_seq = self.category_to_sequence_names()
|
||||||
|
pick_sequences.update(
|
||||||
|
seq for cat in self.pick_categories for seq in cat_to_seq[cat]
|
||||||
|
)
|
||||||
|
|
||||||
|
if pick_sequences:
|
||||||
|
old_len = len(eval_batches)
|
||||||
|
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
|
||||||
|
logger.warn(
|
||||||
|
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.exclude_sequences:
|
||||||
|
old_len = len(eval_batches)
|
||||||
|
exclude_sequences = set(self.exclude_sequences)
|
||||||
|
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
|
||||||
|
logger.warn(
|
||||||
|
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return eval_batches
|
||||||
|
|
||||||
|
def _stratified_sample_index(self, index):
|
||||||
|
# NOTE this stratified sampling can be done more efficiently in
|
||||||
|
# the no-subset case above if it is added to the SQL query.
|
||||||
|
# We keep this generic implementation since no-subset case is uncommon
|
||||||
|
index = index.groupby("sequence_name", group_keys=False).apply(
|
||||||
|
lambda seq_frames: seq_frames.sample(
|
||||||
|
min(len(seq_frames), self.n_frames_per_sequence),
|
||||||
|
random_state=(
|
||||||
|
_seq_name_to_seed(seq_frames.iloc[0]["sequence_name"]) + self.seed
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(f" -> retained {len(index)} samples aster stratified sampling.")
|
||||||
|
return index
|
||||||
|
|
||||||
|
def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]:
|
||||||
|
return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"]
|
||||||
|
|
||||||
|
def _get_frame_no_coalesced_ts_by_row_indices(
|
||||||
|
self,
|
||||||
|
idxs: Sequence[int],
|
||||||
|
seq_name: Optional[str] = None,
|
||||||
|
subset_filter: Union[Sequence[str], str, None] = None,
|
||||||
|
) -> Tuple[pd.DataFrame, Sequence[int]]:
|
||||||
|
"""
|
||||||
|
Loads timestamps for given index rows belonging to the same sequence.
|
||||||
|
If seq_name is known, it speeds up the computation.
|
||||||
|
Raises ValueError if `idxs` do not all belong to a single sequences .
|
||||||
|
"""
|
||||||
|
index_slice = self._index.iloc[idxs]
|
||||||
|
if subset_filter is not None:
|
||||||
|
if isinstance(subset_filter, str):
|
||||||
|
subset_filter = [subset_filter]
|
||||||
|
indicator = index_slice["subset"].isin(subset_filter)
|
||||||
|
index_slice = index_slice.loc[indicator]
|
||||||
|
idxs = [i for i, isin in zip(idxs, indicator) if isin]
|
||||||
|
|
||||||
|
frames = index_slice.index.get_level_values("frame_number").tolist()
|
||||||
|
if seq_name is None:
|
||||||
|
seq_name_list = index_slice.index.get_level_values("sequence_name").tolist()
|
||||||
|
seq_name_set = set(seq_name_list)
|
||||||
|
if len(seq_name_set) > 1:
|
||||||
|
raise ValueError("Given indices belong to more than one sequence.")
|
||||||
|
elif len(seq_name_set) == 1:
|
||||||
|
seq_name = seq_name_list[0]
|
||||||
|
|
||||||
|
coalesced_ts = sa.sql.functions.coalesce(
|
||||||
|
self.frame_annotations_type.frame_timestamp, 0
|
||||||
|
)
|
||||||
|
stmt = sa.select(
|
||||||
|
coalesced_ts.label("frame_timestamp"),
|
||||||
|
self.frame_annotations_type.frame_number,
|
||||||
|
).where(
|
||||||
|
self.frame_annotations_type.sequence_name == seq_name,
|
||||||
|
self.frame_annotations_type.frame_number.in_(frames),
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||||
|
|
||||||
|
if len(frame_no_ts) != len(index_slice):
|
||||||
|
raise ValueError(
|
||||||
|
"Not all indices are found in the database; "
|
||||||
|
"do they belong to more than one sequence?"
|
||||||
|
)
|
||||||
|
|
||||||
|
return frame_no_ts, idxs
|
||||||
|
|
||||||
|
def _local_path(self, path: str) -> str:
|
||||||
|
if self.path_manager is None:
|
||||||
|
return path
|
||||||
|
return self.path_manager.get_local_path(path)
|
||||||
|
|
||||||
|
def _get_temp_index_table_instance(self, table_name: str = "__index"):
|
||||||
|
CachedTable = self.frame_annotations_type.metadata.tables.get(table_name)
|
||||||
|
if CachedTable is not None: # table definition is not idempotent
|
||||||
|
return CachedTable
|
||||||
|
|
||||||
|
return sa.Table(
|
||||||
|
table_name,
|
||||||
|
self.frame_annotations_type.metadata,
|
||||||
|
sa.Column("sequence_name", sa.String, primary_key=True),
|
||||||
|
sa.Column("frame_number", sa.Integer, primary_key=True),
|
||||||
|
sa.Column("_image_path", sa.String),
|
||||||
|
sa.Column("subset", sa.String),
|
||||||
|
prefixes=["TEMP"], # NOTE SQLite specific!
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seq_name_to_seed(seq_name) -> int:
|
||||||
|
"""Generates numbers in [0, 2 ** 28)"""
|
||||||
|
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_as_tensor(data, dtype):
|
||||||
|
return torch.tensor(data, dtype=dtype) if data is not None else None
|
424
pytorch3d/implicitron/dataset/sql_dataset_provider.py
Normal file
424
pytorch3d/implicitron/dataset/sql_dataset_provider.py
Normal file
@ -0,0 +1,424 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
expand_args_fields,
|
||||||
|
registry,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .sql_dataset import SqlIndexDataset
|
||||||
|
|
||||||
|
|
||||||
|
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
|
||||||
|
|
||||||
|
# _NEED_CONTROL is a list of those elements of SqlIndexDataset which
|
||||||
|
# are not directly specified for it in the config but come from the
|
||||||
|
# DatasetMapProvider.
|
||||||
|
_NEED_CONTROL: Tuple[str, ...] = (
|
||||||
|
"path_manager",
|
||||||
|
"subsets",
|
||||||
|
"sqlite_metadata_file",
|
||||||
|
"subset_lists_file",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||||
|
"""
|
||||||
|
Generates the training, validation, and testing dataset objects for
|
||||||
|
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||||
|
|
||||||
|
The dataset is organized in the filesystem as follows::
|
||||||
|
|
||||||
|
self.dataset_root
|
||||||
|
├── <possible/partition/0>
|
||||||
|
│ ├── <sequence_name_0>
|
||||||
|
│ │ ├── depth_masks
|
||||||
|
│ │ ├── depths
|
||||||
|
│ │ ├── images
|
||||||
|
│ │ ├── masks
|
||||||
|
│ │ └── pointcloud.ply
|
||||||
|
│ ├── <sequence_name_1>
|
||||||
|
│ │ ├── depth_masks
|
||||||
|
│ │ ├── depths
|
||||||
|
│ │ ├── images
|
||||||
|
│ │ ├── masks
|
||||||
|
│ │ └── pointcloud.ply
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── <sequence_name_N>
|
||||||
|
│ ├── set_lists
|
||||||
|
│ ├── <subset_base_name_0>.json
|
||||||
|
│ ├── <subset_base_name_1>.json
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── <subset_base_name_2>.json
|
||||||
|
│ ├── eval_batches
|
||||||
|
│ │ ├── <eval_batches_base_name_0>.json
|
||||||
|
│ │ ├── <eval_batches_base_name_1>.json
|
||||||
|
│ │ ├── ...
|
||||||
|
│ │ ├── <eval_batches_base_name_M>.json
|
||||||
|
│ ├── frame_annotations.jgz
|
||||||
|
│ ├── sequence_annotations.jgz
|
||||||
|
├── <possible/partition/1>
|
||||||
|
├── ...
|
||||||
|
├── <possible/partition/K>
|
||||||
|
├── set_lists
|
||||||
|
├── <subset_base_name_0>.sqlite
|
||||||
|
├── <subset_base_name_1>.sqlite
|
||||||
|
├── ...
|
||||||
|
├── <subset_base_name_2>.sqlite
|
||||||
|
├── eval_batches
|
||||||
|
│ ├── <eval_batches_base_name_0>.json
|
||||||
|
│ ├── <eval_batches_base_name_1>.json
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── <eval_batches_base_name_M>.json
|
||||||
|
|
||||||
|
The dataset contains sequences named `<sequence_name_i>` that may be partitioned by
|
||||||
|
directories such as `<possible/partition/0>` e.g. representing categories but they
|
||||||
|
can also be stored in a flat structure. Each sequence folder contains the list of
|
||||||
|
sequence images, depth maps, foreground masks, and valid-depth masks
|
||||||
|
`images`, `depths`, `masks`, and `depth_masks` respectively. Furthermore,
|
||||||
|
`set_lists/` dirtectories (with partitions or global) store json or sqlite files
|
||||||
|
`<subset_base_name_l>.<ext>`, each describing a certain sequence subset.
|
||||||
|
These subset path conventions are not hard-coded and arbitrary relative path can be
|
||||||
|
specified by setting `self.subset_lists_path` to the relative path w.r.t.
|
||||||
|
dataset root.
|
||||||
|
|
||||||
|
Each `<subset_base_name_l>.json` file contains the following dictionary::
|
||||||
|
|
||||||
|
{
|
||||||
|
"train": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"val": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"test": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
defining the list of frames (identified with their `sequence_name` and
|
||||||
|
`frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of
|
||||||
|
SQLite format, `<subset_base_name_l>.sqlite` contains a table with the header::
|
||||||
|
|
||||||
|
| sequence_name | frame_number | image_path | subset |
|
||||||
|
|
||||||
|
Note that `frame_number` can be obtained only from the metadata and
|
||||||
|
does not necesarrily correspond to the numeric suffix of the corresponding image
|
||||||
|
file name (e.g. a file `<partition_0>/<sequence_name_0>/images/frame00005.jpg` can
|
||||||
|
have its frame number set to `20`, not 5).
|
||||||
|
|
||||||
|
Each `<eval_batches_base_name_M>.json` file contains a list of evaluation examples
|
||||||
|
in the following form::
|
||||||
|
|
||||||
|
[
|
||||||
|
[ # batch 1
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
[ # batch 2
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
Note that the evaluation examples always come from the `"test"` subset of the dataset.
|
||||||
|
(test frames can repeat across batches). The batches can contain single element,
|
||||||
|
which is typical in case of regular radiance field fitting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subset_lists_path: The relative path to the dataset subset definition.
|
||||||
|
For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json".
|
||||||
|
By default (None), dataset is not partitioned to subsets (in that case, setting
|
||||||
|
`ignore_subsets` will speed up construction)
|
||||||
|
dataset_root: The root folder of the dataset.
|
||||||
|
metadata_basename: name of the SQL metadata file in dataset_root;
|
||||||
|
not expected to be changed by users
|
||||||
|
test_on_train: Construct validation and test datasets from
|
||||||
|
the training subset; note that in practice, in this
|
||||||
|
case all subset dataset objects will be same
|
||||||
|
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||||
|
ignore_subsets: Don’t filter by subsets in the dataset; note that in this
|
||||||
|
case all subset datasets will be same
|
||||||
|
eval_batch_num_training_frames: Add a certain number of training frames to each
|
||||||
|
eval batch. Useful for evaluating models that require
|
||||||
|
source views as input (e.g. NeRF-WCE / PixelNeRF).
|
||||||
|
dataset_args: Specifies additional arguments to the
|
||||||
|
JsonIndexDataset constructor call.
|
||||||
|
path_manager_factory: (Optional) An object that generates an instance of
|
||||||
|
PathManager that can translate provided file paths.
|
||||||
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
category: Optional[str] = None
|
||||||
|
subset_list_name: Optional[str] = None # TODO: docs
|
||||||
|
# OR
|
||||||
|
subset_lists_path: Optional[str] = None
|
||||||
|
eval_batches_path: Optional[str] = None
|
||||||
|
|
||||||
|
dataset_root: str = _CO3D_SQL_DATASET_ROOT
|
||||||
|
metadata_basename: str = "metadata.sqlite"
|
||||||
|
|
||||||
|
test_on_train: bool = False
|
||||||
|
only_test_set: bool = False
|
||||||
|
ignore_subsets: bool = False
|
||||||
|
train_subsets: Tuple[str, ...] = ("train",)
|
||||||
|
val_subsets: Tuple[str, ...] = ("val",)
|
||||||
|
test_subsets: Tuple[str, ...] = ("test",)
|
||||||
|
|
||||||
|
eval_batch_num_training_frames: int = 0
|
||||||
|
|
||||||
|
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||||
|
dataset_class_type: str = "SqlIndexDataset"
|
||||||
|
dataset: SqlIndexDataset
|
||||||
|
|
||||||
|
path_manager_factory: PathManagerFactory
|
||||||
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
if self.only_test_set and self.test_on_train:
|
||||||
|
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||||
|
|
||||||
|
if self.ignore_subsets and not self.only_test_set:
|
||||||
|
self.test_on_train = True # no point in loading same data 3 times
|
||||||
|
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
|
||||||
|
sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename)
|
||||||
|
sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file)
|
||||||
|
|
||||||
|
if not os.path.isfile(sqlite_metadata_file):
|
||||||
|
# The sqlite_metadata_file does not exist.
|
||||||
|
# Most probably the user has not specified the root folder.
|
||||||
|
raise ValueError(
|
||||||
|
f"Looking for frame annotations in {sqlite_metadata_file}."
|
||||||
|
+ " Please specify a correct dataset_root folder."
|
||||||
|
+ " Note: By default the root folder is taken from the"
|
||||||
|
+ " CO3D_SQL_DATASET_ROOT environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.subset_lists_path and self.subset_list_name:
|
||||||
|
raise ValueError(
|
||||||
|
"subset_lists_path and subset_list_name cannot be both set"
|
||||||
|
)
|
||||||
|
|
||||||
|
subset_lists_file = self._get_lists_file("set_lists")
|
||||||
|
|
||||||
|
# setup the common dataset arguments
|
||||||
|
common_dataset_kwargs = {
|
||||||
|
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
|
||||||
|
"sqlite_metadata_file": sqlite_metadata_file,
|
||||||
|
"dataset_root": self.dataset_root,
|
||||||
|
"subset_lists_file": subset_lists_file,
|
||||||
|
"path_manager": path_manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.category:
|
||||||
|
logger.info(f"Forcing category filter in the datasets to {self.category}")
|
||||||
|
common_dataset_kwargs["pick_categories"] = self.category.split(",")
|
||||||
|
|
||||||
|
# get the used dataset type
|
||||||
|
dataset_type: Type[SqlIndexDataset] = registry.get(
|
||||||
|
SqlIndexDataset, self.dataset_class_type
|
||||||
|
)
|
||||||
|
expand_args_fields(dataset_type)
|
||||||
|
|
||||||
|
if subset_lists_file is not None and not os.path.isfile(subset_lists_file):
|
||||||
|
available_subsets = self._get_available_subsets(
|
||||||
|
OmegaConf.to_object(common_dataset_kwargs["pick_categories"])
|
||||||
|
)
|
||||||
|
msg = f"Cannot find subset list file {self.subset_lists_path}."
|
||||||
|
if available_subsets:
|
||||||
|
msg += f" Some of the available subsets: {str(available_subsets)}."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
train_dataset = None
|
||||||
|
val_dataset = None
|
||||||
|
if not self.only_test_set:
|
||||||
|
# load the training set
|
||||||
|
logger.debug("Constructing train dataset.")
|
||||||
|
train_dataset = dataset_type(
|
||||||
|
**common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets)
|
||||||
|
)
|
||||||
|
logger.info(f"Train dataset: {str(train_dataset)}")
|
||||||
|
|
||||||
|
if self.test_on_train:
|
||||||
|
assert train_dataset is not None
|
||||||
|
val_dataset = test_dataset = train_dataset
|
||||||
|
else:
|
||||||
|
# load the val and test sets
|
||||||
|
if not self.only_test_set:
|
||||||
|
# NOTE: this is always loaded in JsonProviderV2
|
||||||
|
logger.debug("Extracting val dataset.")
|
||||||
|
val_dataset = dataset_type(
|
||||||
|
**common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets)
|
||||||
|
)
|
||||||
|
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||||
|
|
||||||
|
logger.debug("Extracting test dataset.")
|
||||||
|
eval_batches_file = self._get_lists_file("eval_batches")
|
||||||
|
del common_dataset_kwargs["eval_batches_file"]
|
||||||
|
test_dataset = dataset_type(
|
||||||
|
**common_dataset_kwargs,
|
||||||
|
subsets=self._get_subsets(self.test_subsets, True),
|
||||||
|
eval_batches_file=eval_batches_file,
|
||||||
|
)
|
||||||
|
logger.info(f"Test dataset: {str(test_dataset)}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
eval_batches_file is not None
|
||||||
|
and self.eval_batch_num_training_frames > 0
|
||||||
|
):
|
||||||
|
self._extend_eval_batches(test_dataset)
|
||||||
|
|
||||||
|
self._dataset_map = DatasetMap(
|
||||||
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_subsets(self, subsets, is_eval: bool = False):
|
||||||
|
if self.ignore_subsets:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if is_eval and self.eval_batch_num_training_frames > 0:
|
||||||
|
# we will need to have training frames for extended batches
|
||||||
|
return list(subsets) + list(self.train_subsets)
|
||||||
|
|
||||||
|
return subsets
|
||||||
|
|
||||||
|
def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None:
|
||||||
|
rng = np.random.default_rng(seed=0)
|
||||||
|
eval_batches = test_dataset.get_eval_batches()
|
||||||
|
if eval_batches is None:
|
||||||
|
raise ValueError("Eval batches were not loaded!")
|
||||||
|
|
||||||
|
for batch in eval_batches:
|
||||||
|
sequence = batch[0][0]
|
||||||
|
seq_frames = list(
|
||||||
|
test_dataset.sequence_frames_in_order(sequence, self.train_subsets)
|
||||||
|
)
|
||||||
|
idx_to_add = rng.permutation(len(seq_frames))[
|
||||||
|
: self.eval_batch_num_training_frames
|
||||||
|
]
|
||||||
|
batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Called by get_default_args.
|
||||||
|
Certain fields are not exposed on each dataset class
|
||||||
|
but rather are controlled by this provider class.
|
||||||
|
"""
|
||||||
|
for key in _NEED_CONTROL:
|
||||||
|
del args[key]
|
||||||
|
|
||||||
|
def create_dataset(self):
|
||||||
|
# No `dataset` member of this class is created.
|
||||||
|
# The dataset(s) live in `self.get_dataset_map`.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_dataset_map(self) -> DatasetMap:
|
||||||
|
return self._dataset_map # pyre-ignore [16]
|
||||||
|
|
||||||
|
def _get_available_subsets(self, categories: List[str]):
|
||||||
|
"""
|
||||||
|
Get the available subset names for a given category folder (if given) inside
|
||||||
|
a root dataset folder `dataset_root`.
|
||||||
|
"""
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
|
||||||
|
subsets: List[str] = []
|
||||||
|
for prefix in [""] + categories:
|
||||||
|
set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists")
|
||||||
|
if not (
|
||||||
|
(path_manager is not None) and path_manager.isdir(set_list_dir)
|
||||||
|
) and not os.path.isdir(set_list_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
set_list_files = (os.listdir if path_manager is None else path_manager.ls)(
|
||||||
|
set_list_dir
|
||||||
|
)
|
||||||
|
subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files)
|
||||||
|
|
||||||
|
return subsets
|
||||||
|
|
||||||
|
def _get_lists_file(self, flavor: str) -> Optional[str]:
|
||||||
|
if flavor == "eval_batches":
|
||||||
|
subset_lists_path = self.eval_batches_path
|
||||||
|
else:
|
||||||
|
subset_lists_path = self.subset_lists_path
|
||||||
|
|
||||||
|
if not subset_lists_path and not self.subset_list_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
category_elem = ""
|
||||||
|
if self.category and "," not in self.category:
|
||||||
|
# if multiple categories are given, looking for global set lists
|
||||||
|
category_elem = self.category
|
||||||
|
|
||||||
|
subset_lists_path = subset_lists_path or (
|
||||||
|
os.path.join(
|
||||||
|
category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert subset_lists_path
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
# try absolute path first
|
||||||
|
subset_lists_file = _get_local_path_check_extensions(
|
||||||
|
subset_lists_path, path_manager
|
||||||
|
)
|
||||||
|
if subset_lists_file:
|
||||||
|
return subset_lists_file
|
||||||
|
|
||||||
|
full_path = os.path.join(self.dataset_root, subset_lists_path)
|
||||||
|
subset_lists_file = _get_local_path_check_extensions(full_path, path_manager)
|
||||||
|
|
||||||
|
if not subset_lists_file:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Subset lists path given but not found: {full_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return subset_lists_file
|
||||||
|
|
||||||
|
|
||||||
|
def _get_local_path_check_extensions(
|
||||||
|
path, path_manager, extensions=("", ".sqlite", ".json")
|
||||||
|
) -> Optional[str]:
|
||||||
|
for ext in extensions:
|
||||||
|
local = _local_path(path_manager, path + ext)
|
||||||
|
if os.path.isfile(local):
|
||||||
|
return local
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _local_path(path_manager, path: str) -> str:
|
||||||
|
if path_manager is None:
|
||||||
|
return path
|
||||||
|
return path_manager.get_local_path(path)
|
189
pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py
Normal file
189
pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||||
|
DataLoaderMap,
|
||||||
|
SceneBatchSampler,
|
||||||
|
SequenceDataLoaderMapProvider,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||||
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D
|
||||||
|
# and support both eval_batches protocols
|
||||||
|
@registry.register
|
||||||
|
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
|
||||||
|
"""
|
||||||
|
Implementation of DataLoaderMapProviderBase that may use internal eval batches for
|
||||||
|
the test dataset. In particular, if `eval_batches_relpath` is set, it loads
|
||||||
|
eval batches from that json file, otherwise test set is treated in the same way as
|
||||||
|
train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type`
|
||||||
|
are respected.
|
||||||
|
|
||||||
|
If conditioning is not required, then the batch size should
|
||||||
|
be set as 1, and most of the fields do not matter.
|
||||||
|
|
||||||
|
If conditioning is required, each batch will contain one main
|
||||||
|
frame first to predict and the, rest of the elements are for
|
||||||
|
conditioning.
|
||||||
|
|
||||||
|
If images_per_seq_options is left empty, the conditioning
|
||||||
|
frames are picked according to the conditioning type given.
|
||||||
|
This does not have regard to the order of frames in a
|
||||||
|
scene, or which frames belong to what scene.
|
||||||
|
|
||||||
|
If images_per_seq_options is given, then the conditioning types
|
||||||
|
must be SAME and the remaining fields are used.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
batch_size: The size of the batch of the data loader.
|
||||||
|
num_workers: Number of data-loading threads in each data loader.
|
||||||
|
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
|
||||||
|
an epoch is the length of the training set.
|
||||||
|
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
|
||||||
|
an epoch is the length of the validation set.
|
||||||
|
dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of
|
||||||
|
batches in a testing epoch. Or 0 to mean an epoch is the length of the test
|
||||||
|
set.
|
||||||
|
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
|
||||||
|
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
|
||||||
|
value. Empty (the default) means that we are not careful about which frames
|
||||||
|
come from which scene.
|
||||||
|
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
||||||
|
in the sequence. It first sorts the frames by timestimps when available,
|
||||||
|
otherwise by frame numbers, finds the connected segments within the sequence
|
||||||
|
of sufficient length, then samples a random pivot element among them and
|
||||||
|
ideally uses it as a middle of the temporal window, shifting the borders
|
||||||
|
where necessary. This strategy mitigates the bias against shorter segments
|
||||||
|
and their boundaries.
|
||||||
|
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
|
||||||
|
difference in frame_number of neighbouring frames when forming connected
|
||||||
|
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
|
||||||
|
the whole sequence is considered a segment regardless of frame numbers.
|
||||||
|
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
|
||||||
|
maximum difference in frame_timestamp of neighbouring frames when forming
|
||||||
|
connected segments; if both this and consecutive_frames_max_gap are 0s,
|
||||||
|
the whole sequence is considered a segment regardless of frame timestamps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_size: int = 1
|
||||||
|
num_workers: int = 0
|
||||||
|
|
||||||
|
dataset_length_train: int = 0
|
||||||
|
dataset_length_val: int = 0
|
||||||
|
dataset_length_test: int = 0
|
||||||
|
|
||||||
|
images_per_seq_options: Tuple[int, ...] = ()
|
||||||
|
sample_consecutive_frames: bool = False
|
||||||
|
consecutive_frames_max_gap: int = 0
|
||||||
|
consecutive_frames_max_gap_seconds: float = 0.1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||||
|
"""
|
||||||
|
Returns a collection of data loaders for a given collection of datasets.
|
||||||
|
"""
|
||||||
|
train = self._make_generic_data_loader(
|
||||||
|
datasets.train,
|
||||||
|
self.dataset_length_train,
|
||||||
|
datasets.train,
|
||||||
|
)
|
||||||
|
|
||||||
|
val = self._make_generic_data_loader(
|
||||||
|
datasets.val,
|
||||||
|
self.dataset_length_val,
|
||||||
|
datasets.train,
|
||||||
|
)
|
||||||
|
|
||||||
|
if datasets.test is not None and datasets.test.get_eval_batches() is not None:
|
||||||
|
test = self._make_eval_data_loader(datasets.test)
|
||||||
|
else:
|
||||||
|
test = self._make_generic_data_loader(
|
||||||
|
datasets.test,
|
||||||
|
self.dataset_length_test,
|
||||||
|
datasets.train,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DataLoaderMap(train=train, val=val, test=test)
|
||||||
|
|
||||||
|
def _make_eval_data_loader(
|
||||||
|
self,
|
||||||
|
dataset: Optional[DatasetBase],
|
||||||
|
) -> Optional[DataLoader[FrameData]]:
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=dataset.get_eval_batches(),
|
||||||
|
**self._get_data_loader_common_kwargs(dataset),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_generic_data_loader(
|
||||||
|
self,
|
||||||
|
dataset: Optional[DatasetBase],
|
||||||
|
num_batches: int,
|
||||||
|
train_dataset: Optional[DatasetBase],
|
||||||
|
) -> Optional[DataLoader[FrameData]]:
|
||||||
|
"""
|
||||||
|
Returns the dataloader for a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
train_dataset: the training dataset, used if conditioning_type==TRAIN
|
||||||
|
conditioning_type: source for padding of batches
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)
|
||||||
|
|
||||||
|
if len(self.images_per_seq_options) > 0:
|
||||||
|
# this is a typical few-view setup
|
||||||
|
# conditioning comes from the same subset since subsets are split by seqs
|
||||||
|
batch_sampler = SceneBatchSampler(
|
||||||
|
dataset,
|
||||||
|
self.batch_size,
|
||||||
|
num_batches=len(dataset) if num_batches <= 0 else num_batches,
|
||||||
|
images_per_seq_options=self.images_per_seq_options,
|
||||||
|
sample_consecutive_frames=self.sample_consecutive_frames,
|
||||||
|
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
|
||||||
|
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
|
||||||
|
)
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.batch_size == 1:
|
||||||
|
# this is a typical many-view setup (without conditioning)
|
||||||
|
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
|
||||||
|
|
||||||
|
# edge case: conditioning on train subset, typical for Nerformer-like many-view
|
||||||
|
# there is only one sequence in all datasets, so we condition on another subset
|
||||||
|
return self._train_loader(
|
||||||
|
dataset, train_dataset, num_batches, data_loader_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"num_workers": self.num_workers,
|
||||||
|
"collate_fn": dataset.frame_data_type.collate,
|
||||||
|
}
|
1
setup.py
1
setup.py
@ -164,6 +164,7 @@ setup(
|
|||||||
"tqdm>4.29.0",
|
"tqdm>4.29.0",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"accelerate",
|
"accelerate",
|
||||||
|
"sqlalchemy>=2.0",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
entry_points={
|
entry_points={
|
||||||
|
1
tests/implicitron/data/sql_dataset/set_lists_100.json
Normal file
1
tests/implicitron/data/sql_dataset/set_lists_100.json
Normal file
File diff suppressed because one or more lines are too long
BIN
tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite
Normal file
BIN
tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite
Normal file
Binary file not shown.
246
tests/implicitron/test_co3d_sql.py
Normal file
246
tests/implicitron/test_co3d_sql.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
|
||||||
|
SequenceDataLoaderMapProvider,
|
||||||
|
SimpleDataLoaderMapProvider,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
|
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
|
||||||
|
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
|
||||||
|
SqlIndexDatasetMapProvider,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
|
||||||
|
TrainEvalDataLoaderMapProvider,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
|
|
||||||
|
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
|
||||||
|
sh = logging.StreamHandler()
|
||||||
|
logger.addHandler(sh)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
|
||||||
|
class TestCo3dSqlDataSource(unittest.TestCase):
|
||||||
|
def test_no_subsets(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.ignore_subsets = True
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.pick_categories = ["skateboard"]
|
||||||
|
dataset_args.limit_sequences_to = 1
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
self.assertIsInstance(
|
||||||
|
data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
|
||||||
|
)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertEqual(len(data_loaders.train), 202)
|
||||||
|
for frame in data_loaders.train:
|
||||||
|
self.assertIsNone(frame.frame_type)
|
||||||
|
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_subsets(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = (
|
||||||
|
"skateboard/set_lists/set_lists_manyview_dev_0.json"
|
||||||
|
)
|
||||||
|
# this will naturally limit to one sequence (no need to limit by cat/sequence)
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
|
||||||
|
for sampler_type in [
|
||||||
|
"SimpleDataLoaderMapProvider",
|
||||||
|
"SequenceDataLoaderMapProvider",
|
||||||
|
"TrainEvalDataLoaderMapProvider",
|
||||||
|
]:
|
||||||
|
args.data_loader_map_provider_class_type = sampler_type
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertEqual(len(data_loaders.train), 102)
|
||||||
|
self.assertEqual(len(data_loaders.val), 100)
|
||||||
|
self.assertEqual(len(data_loaders.test), 100)
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
for frame in data_loaders[split]:
|
||||||
|
self.assertEqual(frame.frame_type, [split])
|
||||||
|
# check loading blobs
|
||||||
|
self.assertEqual(frame.image_rgb.shape[-1], 800)
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_sql_subsets(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
dataset_args.pick_categories = ["skateboard"]
|
||||||
|
|
||||||
|
for sampler_type in [
|
||||||
|
"SimpleDataLoaderMapProvider",
|
||||||
|
"SequenceDataLoaderMapProvider",
|
||||||
|
"TrainEvalDataLoaderMapProvider",
|
||||||
|
]:
|
||||||
|
args.data_loader_map_provider_class_type = sampler_type
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertEqual(len(data_loaders.train), 102)
|
||||||
|
self.assertEqual(len(data_loaders.val), 100)
|
||||||
|
self.assertEqual(len(data_loaders.test), 100)
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
for frame in data_loaders[split]:
|
||||||
|
self.assertEqual(frame.frame_type, [split])
|
||||||
|
self.assertEqual(
|
||||||
|
frame.image_rgb.shape[-1], 800
|
||||||
|
) # check loading blobs
|
||||||
|
break
|
||||||
|
|
||||||
|
@unittest.skip("It takes 75 seconds; skipping by default")
|
||||||
|
def test_huge_subsets(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertEqual(len(data_loaders.train), 3158974)
|
||||||
|
self.assertEqual(len(data_loaders.val), 518417)
|
||||||
|
self.assertEqual(len(data_loaders.test), 518417)
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
for frame in data_loaders[split]:
|
||||||
|
self.assertEqual(frame.frame_type, [split])
|
||||||
|
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_broken_subsets(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = "et_non_est"
|
||||||
|
provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
|
||||||
|
with self.assertRaises(FileNotFoundError) as err:
|
||||||
|
ImplicitronDataSource(**args)
|
||||||
|
|
||||||
|
# check the hint text
|
||||||
|
self.assertIn("Subset lists path given but not found", str(err.exception))
|
||||||
|
|
||||||
|
def test_eval_batches(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||||
|
provider_args.eval_batches_path = (
|
||||||
|
"skateboard/eval_batches/eval_batches_manyview_dev_0.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
dataset_args.pick_categories = ["skateboard"]
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertEqual(len(data_loaders.train), 102)
|
||||||
|
self.assertEqual(len(data_loaders.val), 100)
|
||||||
|
self.assertEqual(len(data_loaders.test), 50)
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
for frame in data_loaders[split]:
|
||||||
|
self.assertEqual(frame.frame_type, [split])
|
||||||
|
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_eval_batches_from_subset_list_name(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_list_name = "manyview_dev_0"
|
||||||
|
provider_args.category = "skateboard"
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
dataset, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
|
||||||
|
self.assertEqual(len(data_loaders.train), 102)
|
||||||
|
self.assertEqual(len(data_loaders.val), 100)
|
||||||
|
self.assertEqual(len(data_loaders.test), 50)
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
for frame in data_loaders[split]:
|
||||||
|
self.assertEqual(frame.frame_type, [split])
|
||||||
|
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_frame_access(self):
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||||
|
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||||
|
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||||
|
|
||||||
|
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||||
|
dataset_args.remove_empty_masks = True
|
||||||
|
dataset_args.pick_categories = ["skateboard"]
|
||||||
|
frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
|
||||||
|
frame_builder_args.load_point_clouds = True
|
||||||
|
frame_builder_args.box_crop = False # required for .meta
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
dataset_map, _ = data_source.get_datasets_and_dataloaders()
|
||||||
|
dataset = dataset_map["train"]
|
||||||
|
|
||||||
|
for idx in [10, ("245_26182_52130", 22)]:
|
||||||
|
example_meta = dataset.meta[idx]
|
||||||
|
example = dataset[idx]
|
||||||
|
|
||||||
|
self.assertIsNone(example_meta.image_rgb)
|
||||||
|
self.assertIsNone(example_meta.fg_probability)
|
||||||
|
self.assertIsNone(example_meta.depth_map)
|
||||||
|
self.assertIsNone(example_meta.sequence_point_cloud)
|
||||||
|
self.assertIsNotNone(example_meta.camera)
|
||||||
|
|
||||||
|
self.assertIsNotNone(example.image_rgb)
|
||||||
|
self.assertIsNotNone(example.fg_probability)
|
||||||
|
self.assertIsNotNone(example.depth_map)
|
||||||
|
self.assertIsNotNone(example.sequence_point_cloud)
|
||||||
|
self.assertIsNotNone(example.camera)
|
||||||
|
|
||||||
|
self.assertEqual(example_meta.sequence_name, example.sequence_name)
|
||||||
|
self.assertEqual(example_meta.frame_number, example.frame_number)
|
||||||
|
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
|
||||||
|
self.assertEqual(example_meta.sequence_category, example.sequence_category)
|
||||||
|
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
|
||||||
|
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
example_meta.camera.focal_length, example.camera.focal_length
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
example_meta.camera.principal_point, example.camera.principal_point
|
||||||
|
)
|
522
tests/implicitron/test_sql_dataset.py
Normal file
522
tests/implicitron/test_sql_dataset.py
Normal file
@ -0,0 +1,522 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import pkg_resources
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
|
||||||
|
|
||||||
|
NO_BLOBS_KWARGS = {
|
||||||
|
"dataset_root": "",
|
||||||
|
"load_images": False,
|
||||||
|
"load_depths": False,
|
||||||
|
"load_masks": False,
|
||||||
|
"load_depth_masks": False,
|
||||||
|
"box_crop": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
|
||||||
|
sh = logging.StreamHandler()
|
||||||
|
logger.addHandler(sh)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
|
||||||
|
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
|
||||||
|
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSqlDataset(unittest.TestCase):
|
||||||
|
def test_basic(self, sequence="cat1_seq2", frame_number=4):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
# check the items are consecutive
|
||||||
|
past_sequences = set()
|
||||||
|
last_frame_number = -1
|
||||||
|
last_sequence = ""
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
|
||||||
|
if item.frame_number == 0:
|
||||||
|
self.assertNotIn(item.sequence_name, past_sequences)
|
||||||
|
past_sequences.add(item.sequence_name)
|
||||||
|
last_sequence = item.sequence_name
|
||||||
|
else:
|
||||||
|
self.assertEqual(item.sequence_name, last_sequence)
|
||||||
|
self.assertEqual(item.frame_number, last_frame_number + 1)
|
||||||
|
|
||||||
|
last_frame_number = item.frame_number
|
||||||
|
|
||||||
|
# test indexing
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[len(dataset) + 1]
|
||||||
|
|
||||||
|
# test sequence-frame indexing
|
||||||
|
item = dataset[sequence, frame_number]
|
||||||
|
self.assertEqual(item.sequence_name, sequence)
|
||||||
|
self.assertEqual(item.frame_number, frame_number)
|
||||||
|
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[sequence, 13]
|
||||||
|
|
||||||
|
def test_filter_empty_masks(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 78)
|
||||||
|
|
||||||
|
def test_pick_frames_sql_clause(self):
|
||||||
|
dataset_no_empty_masks = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the datasets are equal
|
||||||
|
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item_nem = dataset_no_empty_masks[i]
|
||||||
|
item = dataset[i]
|
||||||
|
self.assertEqual(item_nem.image_path, item.image_path)
|
||||||
|
|
||||||
|
# remove_empty_masks together with the custom criterion
|
||||||
|
dataset_ts = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
pick_frames_sql_clause="frame_timestamp < 0.15",
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(dataset_ts), 19)
|
||||||
|
|
||||||
|
def test_limit_categories(self, category="cat0"):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
pick_categories=[category],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 50)
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
self.assertEqual(dataset[i].sequence_category, category)
|
||||||
|
|
||||||
|
def test_limit_sequences(self, num_sequences=3):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_sequences_to=num_sequences,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 10 * num_sequences)
|
||||||
|
|
||||||
|
def delist(sequence_name):
|
||||||
|
return sequence_name if isinstance(sequence_name, str) else sequence_name[0]
|
||||||
|
|
||||||
|
unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
|
||||||
|
self.assertEqual(len(unique_seqs), num_sequences)
|
||||||
|
|
||||||
|
def test_pick_exclude_sequencess(self, sequence="cat1_seq2"):
|
||||||
|
# pick sequence
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
pick_sequences=[sequence],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 10)
|
||||||
|
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
|
||||||
|
self.assertCountEqual(unique_seqs, {sequence})
|
||||||
|
|
||||||
|
item = dataset[sequence, 0]
|
||||||
|
self.assertEqual(item.sequence_name, sequence)
|
||||||
|
self.assertEqual(item.frame_number, 0)
|
||||||
|
|
||||||
|
# exclude sequence
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
exclude_sequences=[sequence],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 90)
|
||||||
|
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
|
||||||
|
self.assertNotIn(sequence, unique_seqs)
|
||||||
|
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[sequence, 0]
|
||||||
|
|
||||||
|
def test_limit_frames(self, num_frames=13):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_to=num_frames,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), num_frames)
|
||||||
|
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
|
||||||
|
self.assertEqual(len(unique_seqs), 2)
|
||||||
|
|
||||||
|
# test when the limit is not binding
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_to=1000,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
def test_limit_frames_per_sequence(self, num_frames=2):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
n_frames_per_sequence=num_frames,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), num_frames * 10)
|
||||||
|
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
|
||||||
|
self.assertEqual(len(seq_counts), 10)
|
||||||
|
self.assertCountEqual(
|
||||||
|
set(seq_counts.values()), {2}
|
||||||
|
) # all counts are num_frames
|
||||||
|
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[next(iter(seq_counts)), num_frames + 1]
|
||||||
|
|
||||||
|
# test when the limit is not binding
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
n_frames_per_sequence=13,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
def test_filter_medley(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
pick_categories=["cat1"],
|
||||||
|
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
|
||||||
|
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
|
||||||
|
limit_to=14, # retaining full "cat1_seq1" and 4 from "cat1_seq2"
|
||||||
|
n_frames_per_sequence=6, # cutting "cat1_seq1" to 6 frames
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
|
||||||
|
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
|
||||||
|
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
|
||||||
|
self.assertEqual(seq_counts["cat1_seq1"], 6)
|
||||||
|
self.assertEqual(seq_counts["cat1_seq2"], 4)
|
||||||
|
|
||||||
|
def test_subsets_trivial(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
limit_to=100, # force sorting
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
# check the items are consecutive
|
||||||
|
past_sequences = set()
|
||||||
|
last_frame_number = -1
|
||||||
|
last_sequence = ""
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
|
||||||
|
if item.frame_number == 0:
|
||||||
|
self.assertNotIn(item.sequence_name, past_sequences)
|
||||||
|
past_sequences.add(item.sequence_name)
|
||||||
|
last_sequence = item.sequence_name
|
||||||
|
else:
|
||||||
|
self.assertEqual(item.sequence_name, last_sequence)
|
||||||
|
self.assertEqual(item.frame_number, last_frame_number + 1)
|
||||||
|
|
||||||
|
last_frame_number = item.frame_number
|
||||||
|
|
||||||
|
def test_subsets_filter_empty_masks(self):
|
||||||
|
# we need to test this case as it uses quite different logic with `df.drop()`
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 78)
|
||||||
|
|
||||||
|
def test_subsets_pick_frames_sql_clause(self):
|
||||||
|
dataset_no_empty_masks = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the datasets are equal
|
||||||
|
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item_nem = dataset_no_empty_masks[i]
|
||||||
|
item = dataset[i]
|
||||||
|
self.assertEqual(item_nem.image_path, item.image_path)
|
||||||
|
|
||||||
|
# remove_empty_masks together with the custom criterion
|
||||||
|
dataset_ts = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
pick_frames_sql_clause="frame_timestamp < 0.15",
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset_ts), 19)
|
||||||
|
|
||||||
|
def test_single_subset(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 50)
|
||||||
|
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[51]
|
||||||
|
|
||||||
|
# check the items are consecutive
|
||||||
|
past_sequences = set()
|
||||||
|
last_frame_number = -1
|
||||||
|
last_sequence = ""
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
|
||||||
|
if item.frame_number < 2:
|
||||||
|
self.assertNotIn(item.sequence_name, past_sequences)
|
||||||
|
past_sequences.add(item.sequence_name)
|
||||||
|
last_sequence = item.sequence_name
|
||||||
|
else:
|
||||||
|
self.assertEqual(item.sequence_name, last_sequence)
|
||||||
|
self.assertEqual(item.frame_number, last_frame_number + 2)
|
||||||
|
|
||||||
|
last_frame_number = item.frame_number
|
||||||
|
|
||||||
|
item = dataset[last_sequence, 0]
|
||||||
|
self.assertEqual(item.sequence_name, last_sequence)
|
||||||
|
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
dataset[last_sequence, 1]
|
||||||
|
|
||||||
|
def test_subset_with_filters(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train"],
|
||||||
|
pick_categories=["cat1"],
|
||||||
|
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
|
||||||
|
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
|
||||||
|
limit_to=7, # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2"
|
||||||
|
n_frames_per_sequence=3, # cutting "cat1_seq1" to 3 frames
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
|
||||||
|
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
|
||||||
|
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
|
||||||
|
self.assertEqual(seq_counts["cat1_seq1"], 3)
|
||||||
|
self.assertEqual(seq_counts["cat1_seq2"], 2)
|
||||||
|
|
||||||
|
def test_visitor(self):
|
||||||
|
dataset_sorted = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequences = dataset_sorted.sequence_names()
|
||||||
|
i = 0
|
||||||
|
for seq in sequences:
|
||||||
|
last_ts = float("-Inf")
|
||||||
|
for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
|
||||||
|
self.assertEqual(i, idx)
|
||||||
|
i += 1
|
||||||
|
self.assertGreaterEqual(ts, last_ts)
|
||||||
|
last_ts = ts
|
||||||
|
|
||||||
|
# test legacy visitor
|
||||||
|
old_indices = None
|
||||||
|
for seq in sequences:
|
||||||
|
last_ts = float("-Inf")
|
||||||
|
rows = dataset_sorted._index.index.get_loc(seq)
|
||||||
|
indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
|
||||||
|
fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
|
||||||
|
self.assertEqual(len(fn_ts_list), len(indices))
|
||||||
|
|
||||||
|
if old_indices:
|
||||||
|
# check raising if we ask for multiple sequences
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
dataset_sorted.get_frame_numbers_and_timestamps(
|
||||||
|
indices + old_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
old_indices = indices
|
||||||
|
|
||||||
|
def test_visitor_subsets(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_to=100, # force sorting
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequences = dataset.sequence_names()
|
||||||
|
i = 0
|
||||||
|
for seq in sequences:
|
||||||
|
last_ts = float("-Inf")
|
||||||
|
seq_frames = list(dataset.sequence_frames_in_order(seq))
|
||||||
|
self.assertEqual(len(seq_frames), 10)
|
||||||
|
for ts, _, idx in seq_frames:
|
||||||
|
self.assertEqual(i, idx)
|
||||||
|
i += 1
|
||||||
|
self.assertGreaterEqual(ts, last_ts)
|
||||||
|
last_ts = ts
|
||||||
|
|
||||||
|
last_ts = float("-Inf")
|
||||||
|
train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
|
||||||
|
self.assertEqual(len(train_frames), 5)
|
||||||
|
for ts, _, _ in train_frames:
|
||||||
|
self.assertGreaterEqual(ts, last_ts)
|
||||||
|
last_ts = ts
|
||||||
|
|
||||||
|
def test_category_to_sequence_names(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
cat_to_seqs = dataset.category_to_sequence_names()
|
||||||
|
self.assertEqual(len(cat_to_seqs), 2)
|
||||||
|
self.assertIn("cat1", cat_to_seqs)
|
||||||
|
self.assertEqual(len(cat_to_seqs["cat1"]), 5)
|
||||||
|
|
||||||
|
# check that override preserves the behavior
|
||||||
|
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
|
||||||
|
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
|
||||||
|
|
||||||
|
def test_category_to_sequence_names_filters(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=True,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
exclude_sequences=["cat1_seq0"],
|
||||||
|
subsets=["train", "test"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
cat_to_seqs = dataset.category_to_sequence_names()
|
||||||
|
self.assertEqual(len(cat_to_seqs), 2)
|
||||||
|
self.assertIn("cat1", cat_to_seqs)
|
||||||
|
self.assertEqual(len(cat_to_seqs["cat1"]), 4) # minus one
|
||||||
|
|
||||||
|
# check that override preserves the behavior
|
||||||
|
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
|
||||||
|
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
|
||||||
|
|
||||||
|
def test_meta_access(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 50)
|
||||||
|
|
||||||
|
for idx in [10, ("cat0_seq2", 2)]:
|
||||||
|
example_meta = dataset.meta[idx]
|
||||||
|
example = dataset[idx]
|
||||||
|
self.assertEqual(example_meta.sequence_name, example.sequence_name)
|
||||||
|
self.assertEqual(example_meta.frame_number, example.frame_number)
|
||||||
|
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
|
||||||
|
self.assertEqual(example_meta.sequence_category, example.sequence_category)
|
||||||
|
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
|
||||||
|
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
example_meta.camera.focal_length, example.camera.focal_length
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
example_meta.camera.principal_point, example.camera.principal_point
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_meta_access_no_blobs(self):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
subset_lists_file=SET_LIST_FILE,
|
||||||
|
subsets=["train"],
|
||||||
|
frame_data_builder_FrameDataBuilder_args={
|
||||||
|
"dataset_root": ".",
|
||||||
|
"box_crop": False, # required by blob-less accessor
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNone(dataset.meta[0].image_rgb)
|
||||||
|
self.assertIsNone(dataset.meta[0].fg_probability)
|
||||||
|
self.assertIsNone(dataset.meta[0].depth_map)
|
||||||
|
self.assertIsNone(dataset.meta[0].sequence_point_cloud)
|
||||||
|
self.assertIsNotNone(dataset.meta[0].camera)
|
Loading…
x
Reference in New Issue
Block a user