SQL Index Dataset

Summary:
Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay.

It requires SQLAlchemy 2.0, which is not supported in fbcode, so I exclude the sources and tests that depend on it from the Buck TARGETS.

Reviewed By: bottler

Differential Revision: D45086611

fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
Roman Shapovalov 2023-04-25 09:56:15 -07:00 committed by Facebook GitHub Bot
parent 7aeedd17a4
commit 32e1992924
10 changed files with 2309 additions and 6 deletions

View File

@@ -450,6 +450,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
         self,
         frame_annotation: types.FrameAnnotation,
         sequence_annotation: types.SequenceAnnotation,
+        load_blobs: bool = True,
     ) -> FrameDataSubtype:
         """An abstract method to build the frame data based on raw frame/sequence
         annotations, load the binary data and adjust them according to the metadata.
@@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
     Beware that modifications of frame data are done in-place.

     Args:
-        dataset_root: The root folder of the dataset; all the paths in jsons are
-            specified relative to this root (but not json paths themselves).
+        dataset_root: The root folder of the dataset; all paths in frame / sequence
+            annotations are defined w.r.t. this root. Has to be set if any of the
+            load_* flags below is true.
         load_images: Enable loading the frame RGB data.
         load_depths: Enable loading the frame depth maps.
         load_depth_masks: Enable loading the frame depth map masks denoting the
@@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         path_manager: Optionally a PathManager for interpreting paths in a special way.
     """

-    dataset_root: str = ""
+    dataset_root: Optional[str] = None
     load_images: bool = True
     load_depths: bool = True
     load_depth_masks: bool = True
@@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
     box_crop_context: float = 0.3
     path_manager: Any = None

+    def __post_init__(self) -> None:
+        load_any_blob = (
+            self.load_images
+            or self.load_depths
+            or self.load_depth_masks
+            or self.load_masks
+            or self.load_point_clouds
+        )
+        if load_any_blob and self.dataset_root is None:
+            raise ValueError(
+                "dataset_root must be set to load any blob data. "
+                "Make sure it is set in either FrameDataBuilder or Dataset params."
+            )
+
+        if load_any_blob and not os.path.isdir(self.dataset_root):  # pyre-ignore
+            raise ValueError(
+                f"dataset_root is passed but {self.dataset_root} does not exist."
+            )
+
     def build(
         self,
         frame_annotation: types.FrameAnnotation,
@@ -567,7 +588,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         if bbox_xywh is None and fg_mask_np is not None:
             bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

-        frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
+        frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

         if frame_annotation.image is not None:
             image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@@ -612,7 +633,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
     def _load_fg_probability(
         self, entry: types.FrameAnnotation
     ) -> Tuple[np.ndarray, str]:
-        full_path = os.path.join(self.dataset_root, entry.mask.path)  # pyre-ignore
+        assert self.dataset_root is not None and entry.mask is not None
+        full_path = os.path.join(self.dataset_root, entry.mask.path)
         fg_probability = load_mask(self._local_path(full_path))
         if fg_probability.shape[-2:] != entry.image.size:
             raise ValueError(
@@ -647,7 +669,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_probability: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, str, torch.Tensor]:
         entry_depth = entry.depth
-        assert entry_depth is not None
+        assert self.dataset_root is not None and entry_depth is not None
         path = os.path.join(self.dataset_root, entry_depth.path)

         depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
@@ -657,6 +679,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         if self.load_depth_masks:
             assert entry_depth.mask_path is not None
+            # pyre-ignore
             mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
             depth_mask = load_depth_mask(self._local_path(mask_path))
         else:
@@ -705,6 +728,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
             )
         if path.startswith(unwanted_prefix):
             path = path[len(unwanted_prefix) :]
+        assert self.dataset_root is not None
         return os.path.join(self.dataset_root, path)

     def _local_path(self, path: str) -> str:
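For reference, a minimal sketch of what the new `__post_init__` validation means for callers. This is not part of the diff; it assumes the builder is constructed via `expand_args_fields`, the same pattern the dataset map provider below uses for its own members:

from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(FrameDataBuilder)  # make the configurable class constructible with kwargs

# With every load_* flag disabled, dataset_root may stay None (metadata-only use).
metadata_only_builder = FrameDataBuilder(
    load_images=False,
    load_depths=False,
    load_depth_masks=False,
    load_masks=False,
    load_point_clouds=False,
)

# With any load_* flag left at its default (True) and dataset_root=None, or with a
# dataset_root that is not an existing directory, __post_init__ now raises ValueError.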

View File

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This functionality requires SQLAlchemy 2.0 or later.
import math
import struct
from typing import Optional, Tuple
import numpy as np
from pytorch3d.implicitron.dataset.types import (
DepthAnnotation,
ImageAnnotation,
MaskAnnotation,
PointCloudAnnotation,
VideoAnnotation,
ViewpointAnnotation,
)
from sqlalchemy import LargeBinary
from sqlalchemy.orm import (
composite,
DeclarativeBase,
Mapped,
mapped_column,
MappedAsDataclass,
)
from sqlalchemy.types import TypeDecorator
# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape):
class NumpyArrayType(TypeDecorator):
impl = LargeBinary
def process_bind_param(self, value, dialect):
if value is not None:
if value.shape != shape:
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
return value.astype(np.float32).tobytes()
return None
def process_result_value(self, value, dialect):
if value is not None:
return np.frombuffer(value, dtype=np.float32).reshape(shape)
return None
return NumpyArrayType
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
format_symbol = {
float: "f", # float32
int: "i", # int32
}[dtype]
class TupleType(TypeDecorator):
impl = LargeBinary
_format = format_symbol * math.prod(shape)
def process_bind_param(self, value, _):
if value is None:
return None
if len(shape) > 1:
value = np.array(value, dtype=dtype).reshape(-1)
return struct.pack(TupleType._format, *value)
def process_result_value(self, value, _):
if value is None:
return None
loaded = struct.unpack(TupleType._format, value)
if len(shape) > 1:
loaded = _rec_totuple(
np.array(loaded, dtype=dtype).reshape(shape).tolist()
)
return loaded
return TupleType
def _rec_totuple(t):
if isinstance(t, list):
return tuple(_rec_totuple(x) for x in t)
return t
class Base(MappedAsDataclass, DeclarativeBase):
"""subclasses will be converted to dataclasses"""
class SqlFrameAnnotation(Base):
__tablename__ = "frame_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
frame_number: Mapped[int] = mapped_column(primary_key=True)
frame_timestamp: Mapped[float] = mapped_column(index=True)
image: Mapped[ImageAnnotation] = composite(
mapped_column("_image_path"),
mapped_column("_image_size", TupleTypeFactory(int)),
)
depth: Mapped[DepthAnnotation] = composite(
mapped_column("_depth_path", nullable=True),
mapped_column("_depth_scale_adjustment", nullable=True),
mapped_column("_depth_mask_path", nullable=True),
)
mask: Mapped[MaskAnnotation] = composite(
mapped_column("_mask_path", nullable=True),
mapped_column("_mask_mass", index=True, nullable=True),
mapped_column(
"_mask_bounding_box_xywh",
TupleTypeFactory(float, shape=(4,)),
nullable=True,
),
)
viewpoint: Mapped[ViewpointAnnotation] = composite(
mapped_column(
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
),
mapped_column(
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
),
mapped_column(
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
),
mapped_column(
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
),
mapped_column("_viewpoint_intrinsics_format", nullable=True),
)
class SqlSequenceAnnotation(Base):
__tablename__ = "sequence_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
category: Mapped[str] = mapped_column(index=True)
video: Mapped[VideoAnnotation] = composite(
mapped_column("_video_path", nullable=True),
mapped_column("_video_length", nullable=True),
)
point_cloud: Mapped[PointCloudAnnotation] = composite(
mapped_column("_point_cloud_path", nullable=True),
mapped_column("_point_cloud_quality_score", nullable=True),
mapped_column("_point_cloud_n_points", nullable=True),
)
# the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
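To illustrate the blob-serialization policy implemented by the factories above, here is a small round-trip sketch. It is not part of the committed file; it assumes the module lands at pytorch3d/implicitron/dataset/orm_types.py (as the relative import in sql_dataset.py suggests) and passes None for the unused dialect argument:

import numpy as np
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory

# A 3x3 tuple (e.g. a viewpoint rotation) is packed into 9 float32 values by struct.pack.
RotationType = TupleTypeFactory(float, shape=(3, 3))
blob = RotationType().process_bind_param(
    ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)), None
)
restored = RotationType().process_result_value(blob, None)  # nested 3x3 tuple of floats

# Fixed-shape numpy arrays round-trip through raw float32 bytes.
ArrayType4 = ArrayTypeFactory((4,))
arr_blob = ArrayType4().process_bind_param(np.arange(4, dtype=np.float32), None)
arr = ArrayType4().process_result_value(arr_blob, None)  # float32 array of shape (4,)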

View File

@ -0,0 +1,735 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import json
import logging
import os
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np
import pandas as pd
import sqlalchemy as sa
import torch
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
FrameData,
FrameDataBuilder,
FrameDataBuilderBase,
)
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from sqlalchemy.orm import Session
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
logger = logging.getLogger(__name__)
_SET_LISTS_TABLE: str = "set_lists"
@registry.register
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
"""
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
The length is returned after all sequence and frame filters are applied (see param
definitions below). Indices can either be ordinal in [0, len), or pairs of
(sequence_name, frame_number); the performance of `dataset[i]` and
`dataset[sequence_name, frame_number]` is the same. A faster way to get metadata only
(without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False.
With ordinal indexing, the sequences are NOT guaranteed to span contiguous index
ranges, and frame numbers are NOT guaranteed to be increasing within a sequence.
Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order`
iterators, which are efficient.
This functionality requires SQLAlchemy 2.0 or later.
Metadata-related args:
sqlite_metadata_file: A SQLite file containing frame and sequence annotation
tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation,
respectively).
dataset_root: A root directory to look for images, masks, etc. It can be
alternatively set in `frame_data_builder` args, but this takes precedence.
subset_lists_file: A JSON/sqlite file containing the lists of frames
corresponding to different subsets (e.g. train/val/test) of the dataset;
format: {subset: [(sequence_name, frame_id, file_path)]}. All entries
must be present in frame_annotation metadata table.
path_manager: a facade for non-POSIX filesystems.
subsets: Restrict frames/sequences only to the given list of subsets
as defined in subset_lists_file (see above). Applied before all other
filters.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask (needs frame_annotation.mask.mass to be set;
null values are retained).
pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations.
NOTE: This is a potential security risk! The string is passed to the SQL
engine verbatim. Don't expose it to end users of your application!
pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to.
exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
sequences (after other sequence filters have been applied but before
frame-based filters).
limit_to: Limit the dataset to the first #limit_to frames (after other
filters have been applied, except n_frames_per_sequence).
n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence`
frames in each sequence uniformly without replacement if it has
more frames than that; applied after other frame-level filters.
seed: The seed of the random generator sampling `n_frames_per_sequence`
random frames per sequence.
"""
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
sqlite_metadata_file: str = ""
dataset_root: Optional[str] = None
subset_lists_file: str = ""
eval_batches_file: Optional[str] = None
path_manager: Any = None
subsets: Optional[List[str]] = None
remove_empty_masks: bool = True
pick_frames_sql_clause: Optional[str] = None
pick_categories: Tuple[str, ...] = ()
pick_sequences: Tuple[str, ...] = ()
exclude_sequences: Tuple[str, ...] = ()
limit_sequences_to: int = 0
limit_to: int = 0
n_frames_per_sequence: int = -1
seed: int = 0
remove_empty_masks_poll_whole_table_threshold: int = 300_000
# we set it manually in the constructor
# _index: pd.DataFrame = field(init=False)
frame_data_builder: FrameDataBuilderBase
frame_data_builder_class_type: str = "FrameDataBuilder"
def __post_init__(self) -> None:
if sa.__version__ < "2.0":
raise ImportError("This class requires SQL Alchemy 2.0 or later")
if not self.sqlite_metadata_file:
raise ValueError("sqlite_metadata_file must be set")
if self.dataset_root:
frame_builder_type = self.frame_data_builder_class_type
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
"dataset_root"
] = self.dataset_root
run_auto_creation(self)
# pyre-ignore
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
sequences = self._get_filtered_sequences_if_any()
if self.subsets:
index = self._build_index_from_subset_lists(sequences)
else:
# TODO: if self.subset_lists_file and not self.subsets, it might be faster to
# still use the concatenated lists, assuming they cover the whole dataset
index = self._build_index_from_db(sequences)
if self.n_frames_per_sequence >= 0:
index = self._stratified_sample_index(index)
if len(index) == 0:
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
self.eval_batches = None # pyre-ignore
if self.eval_batches_file:
self.eval_batches = self._load_filter_eval_batches()
logger.info(str(self))
def __len__(self) -> int:
# pyre-ignore[16]
return len(self._index)
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
"""
Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair
"""
return self._get_item(frame_idx, True)
@property
def meta(self):
"""
Allows accessing metadata only without loading blobs using `dataset.meta[idx]`.
Requires box_crop==False, since in that case, cameras cannot be adjusted
without loading masks.
Returns:
FrameData objects with blob fields like `image_rgb` set to None.
Raises:
ValueError if dataset.box_crop is set.
"""
return SqlIndexDataset._MetadataAccessor(self)
@dataclass
class _MetadataAccessor:
dataset: "SqlIndexDataset"
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
return self.dataset._get_item(frame_idx, False)
def _get_item(
self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True
) -> FrameData:
if isinstance(frame_idx, int):
if frame_idx >= len(self._index):
raise IndexError(f"index {frame_idx} out of range {len(self._index)}")
seq, frame = self._index.index[frame_idx]
else:
seq, frame, *rest = frame_idx
if (seq, frame) not in self._index.index:
raise IndexError(
f"Sequence-frame index {frame_idx} not found; was it filtered out?"
)
if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]:
raise IndexError(f"Non-matching image path in {frame_idx}.")
stmt = sa.select(self.frame_annotations_type).where(
self.frame_annotations_type.sequence_name == seq,
self.frame_annotations_type.frame_number
== int(frame), # cast from np.int64
)
seq_stmt = sa.select(SqlSequenceAnnotation).where(
SqlSequenceAnnotation.sequence_name == seq
)
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
frame_data = self.frame_data_builder.build(
entry, seq_metadata, load_blobs=load_blobs
)
# The rest of the fields are optional
frame_data.frame_type = self._get_frame_type(entry)
return frame_data
def __str__(self) -> str:
# pyre-ignore[16]
return f"SqlIndexDataset #frames={len(self._index)}"
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self._index.index.unique("sequence_name")
# override
def category_to_sequence_names(self) -> Dict[str, List[str]]:
stmt = sa.select(
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
).where( # we limit results to sequences that have frames after all filters
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
)
with self._sql_engine.connect() as connection:
cat_to_seqs = pd.read_sql(stmt, connection)
return cat_to_seqs.groupby("category")["sequence_name"].apply(list).to_dict()
# override
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
) -> List[Tuple[int, float]]:
"""
Implements the DatasetBase method.
NOTE: Avoid this function as there are more efficient alternatives such as
querying `dataset[idx]` directly or getting all sequence frames with
`sequence_[frames|indices]_in_order`.
Return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. They need to belong to the same sequence!
If timestamps are absent, they are replaced with zeros.
This is used for letting SceneBatchSampler identify consecutive
frames.
Args:
idxs: a sequence of int frame indices in the dataset (it can be a slice)
subset_filter: must remain None
Returns:
list of tuples of
- frame index in video
- timestamp of frame in video, coalesced with 0s
Raises:
ValueError if idxs belong to more than one sequence.
"""
if subset_filter is not None:
raise NotImplementedError(
"Subset filters are not supported in SQL Dataset. "
"We encourage creating a dataset per subset."
)
index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs)
# alternatively, we can use `.values.tolist()`, which may be faster
# but returns a list of lists
return list(index_slice.itertuples())
# override
def sequence_frames_in_order(
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[Tuple[float, int, int]]:
"""
Overrides the default DatasetBase implementation (we don't use `_seq_to_idx`).
Returns an iterator over the frame indices in a given sequence.
We attempt to first sort by timestamp (if they are available),
then by frame number.
Args:
seq_name: the name of the sequence.
subset_filter: subset names to filter to
Returns:
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
where `frame_no` is the index within the sequence, and
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
# TODO: implement sort_timestamp_first? (which would matter if the orders
# of frame numbers and timestamps are different)
rows = self._index.index.get_loc(seq_name)
if isinstance(rows, slice):
assert rows.stop is not None, "Unexpected result from pandas"
rows = range(rows.start or 0, rows.stop, rows.step or 1)
else:
rows = np.where(rows)[0]
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
rows, seq_name, subset_filter
)
index_slice["idx"] = idx
yield from index_slice.itertuples(index=False)
# override
def get_eval_batches(self) -> Optional[List[Any]]:
"""
This class does not support eval batches with ordinal indices. You can pass
eval_batches as a batch_sampler to a data_loader since the dataset supports
`dataset[seq_name, frame_no]` indexing.
"""
return self.eval_batches
# override
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
raise ValueError("Not supported! Preprocess the data by merging them instead.")
# override
@property
def frame_data_type(self) -> Type[FrameData]:
return self.frame_data_builder.frame_data_type
def is_filtered(self) -> bool:
"""
Returns `True` in case the dataset has been filtered and thus some frame
annotations stored on the disk might be missing in the dataset object.
Does not account for subsets.
Returns:
is_filtered: `True` if the dataset has been filtered, else `False`.
"""
return (
self.remove_empty_masks
or self.limit_to > 0
or self.limit_sequences_to > 0
or len(self.pick_sequences) > 0
or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0
)
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# maximum possible query: WHERE category IN 'self.pick_categories'
# AND sequence_name IN 'self.pick_sequences'
# AND sequence_name NOT IN 'self.exclude_sequences'
# LIMIT 'self.limit_sequence_to'
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
where_conditions = [
*self._get_category_filters(),
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if where_conditions:
stmt = stmt.where(*where_conditions)
if self.limit_sequences_to > 0:
logger.info(
f"Limiting dataset to first {self.limit_sequences_to} sequences"
)
# NOTE: ROWID is SQLite-specific
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
if not where_conditions and self.limit_sequences_to <= 0:
# we will not need to filter by sequences
return None
with self._sql_engine.connect() as connection:
sequences = pd.read_sql_query(stmt, connection)["sequence_name"]
logger.info("... retained %d sequences" % len(sequences))
return sequences
def _get_category_filters(self) -> List[sa.ColumnElement]:
if not self.pick_categories:
return []
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
def _get_pick_filters(self) -> List[sa.ColumnElement]:
if not self.pick_sequences:
return []
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
if not self.exclude_sequences:
return []
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
assert self.subsets is not None
with open(subset_lists_path, "r") as f:
subset_to_seq_frame = json.load(f)
seq_frame_list = sum(
(
[(*row, subset) for row in subset_to_seq_frame[subset]]
for subset in self.subsets
),
[],
)
index = pd.DataFrame(
seq_frame_list,
columns=["sequence_name", "frame_number", "_image_path", "subset"],
)
return index
def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame:
subsets = self.subsets
assert subsets is not None
# we need a new engine since we store the subsets in a separate DB
engine = sa.create_engine(f"sqlite:///{subset_lists_path}")
table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine)
stmt = sa.select(table).where(table.c.subset.in_(subsets))
with engine.connect() as connection:
index = pd.read_sql(stmt, connection)
return index
def _build_index_from_subset_lists(
self, sequences: Optional[pd.Series]
) -> pd.DataFrame:
if not self.subset_lists_file:
raise ValueError("Requested subsets but subset_lists_file not given")
logger.info(f"Loading subset lists from {self.subset_lists_file}.")
subset_lists_path = self._local_path(self.subset_lists_file)
if subset_lists_path.lower().endswith(".json"):
index = self._load_subsets_from_json(subset_lists_path)
else:
index = self._load_subsets_from_sql(subset_lists_path)
index = index.set_index(["sequence_name", "frame_number"])
logger.info(f" -> loaded {len(index)} samples of {self.subsets}.")
if sequences is not None:
logger.info("Applying filtered sequences.")
sequence_values = index.index.get_level_values("sequence_name")
index = index.loc[sequence_values.isin(sequences)]
logger.info(f" -> retained {len(index)} samples.")
pick_frames_criteria = []
if self.remove_empty_masks:
logger.info("Culling samples with empty masks.")
if len(index) > self.remove_empty_masks_poll_whole_table_threshold:
# APPROACH 1: find empty masks and drop indices.
# dev load: 17s / 15 s (3.1M / 500K)
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
).where(self.frame_annotations_type._mask_mass == 0)
with Session(self._sql_engine) as session:
to_remove = session.execute(stmt).all()
# Pandas uses np.int64 for integer types, so we have to cast
# we might want to read it into a pandas DataFrame directly to avoid the loop
to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove]
index.drop(to_remove, errors="ignore", inplace=True)
else:
# APPROACH 3: load index into a temp table and join with annotations
# dev load: 94 s / 23 s (3.1M / 500K)
pick_frames_criteria.append(
sa.or_(
self.frame_annotations_type._mask_mass.is_(None),
self.frame_annotations_type._mask_mass != 0,
)
)
if self.pick_frames_sql_clause:
logger.info("Applying the custom SQL clause.")
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
if pick_frames_criteria:
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
logger.info(f" -> retained {len(index)} samples.")
if self.limit_to > 0:
logger.info(f"Limiting dataset to first {self.limit_to} frames")
index = index.sort_index().iloc[: self.limit_to]
return index.reset_index()
def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame:
IndexTable = self._get_temp_index_table_instance()
with self._sql_engine.connect() as connection:
IndexTable.create(connection)
# we don't let pandas's `to_sql` create the table automatically as
# the table would be permanent, so we create it and append with pandas
n_rows = index.to_sql(IndexTable.name, connection, if_exists="append")
assert n_rows == len(index)
sa_type = self.frame_annotations_type
stmt = (
sa.select(IndexTable)
.select_from(
IndexTable.join(
self.frame_annotations_type,
sa.and_(
sa_type.sequence_name == IndexTable.c.sequence_name,
sa_type.frame_number == IndexTable.c.frame_number,
),
)
)
.where(*criteria)
)
return pd.read_sql_query(stmt, connection).set_index(
["sequence_name", "frame_number"]
)
def _build_index_from_db(self, sequences: Optional[pd.Series]):
logger.info("Loading sequcence-frame index from the database")
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
self.frame_annotations_type._image_path,
sa.null().label("subset"),
)
where_conditions = []
if sequences is not None:
logger.info(" applying filtered sequences")
where_conditions.append(
self.frame_annotations_type.sequence_name.in_(sequences.tolist())
)
if self.remove_empty_masks:
logger.info(" excluding samples with empty masks")
where_conditions.append(
sa.or_(
self.frame_annotations_type._mask_mass.is_(None),
self.frame_annotations_type._mask_mass != 0,
)
)
if self.pick_frames_sql_clause:
logger.info(" applying custom SQL clause")
where_conditions.append(sa.text(self.pick_frames_sql_clause))
if where_conditions:
stmt = stmt.where(*where_conditions)
if self.limit_to > 0:
logger.info(f"Limiting dataset to first {self.limit_to} frames")
stmt = stmt.order_by(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
).limit(self.limit_to)
with self._sql_engine.connect() as connection:
index = pd.read_sql_query(stmt, connection)
logger.info(f" -> loaded {len(index)} samples.")
return index
def _sort_index_(self, index):
logger.info("Sorting the index by sequence and frame number.")
index.sort_values(["sequence_name", "frame_number"], inplace=True)
logger.info(" -> Done.")
def _load_filter_eval_batches(self):
assert self.eval_batches_file
logger.info(f"Loading eval batches from {self.eval_batches_file}")
if not os.path.isfile(self.eval_batches_file):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
f"Looking for dataset json file in {self.eval_batches_file}. "
+ "Please specify a correct dataset_root folder."
)
with open(self.eval_batches_file, "r") as f:
eval_batches = json.load(f)
# limit the dataset to sequences to allow multiple evaluations in one file
pick_sequences = set(self.pick_sequences)
if self.pick_categories:
cat_to_seq = self.category_to_sequence_names()
pick_sequences.update(
seq for cat in self.pick_categories for seq in cat_to_seq[cat]
)
if pick_sequences:
old_len = len(eval_batches)
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
logger.warn(
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
)
if self.exclude_sequences:
old_len = len(eval_batches)
exclude_sequences = set(self.exclude_sequences)
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
logger.warn(
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
)
return eval_batches
def _stratified_sample_index(self, index):
# NOTE this stratified sampling can be done more efficiently in
# the no-subset case above if it is added to the SQL query.
# We keep this generic implementation since the no-subset case is uncommon
index = index.groupby("sequence_name", group_keys=False).apply(
lambda seq_frames: seq_frames.sample(
min(len(seq_frames), self.n_frames_per_sequence),
random_state=(
_seq_name_to_seed(seq_frames.iloc[0]["sequence_name"]) + self.seed
),
)
)
logger.info(f" -> retained {len(index)} samples aster stratified sampling.")
return index
def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]:
return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"]
def _get_frame_no_coalesced_ts_by_row_indices(
self,
idxs: Sequence[int],
seq_name: Optional[str] = None,
subset_filter: Union[Sequence[str], str, None] = None,
) -> Tuple[pd.DataFrame, Sequence[int]]:
"""
Loads timestamps for given index rows belonging to the same sequence.
If seq_name is known, it speeds up the computation.
Raises ValueError if `idxs` do not all belong to a single sequence.
"""
index_slice = self._index.iloc[idxs]
if subset_filter is not None:
if isinstance(subset_filter, str):
subset_filter = [subset_filter]
indicator = index_slice["subset"].isin(subset_filter)
index_slice = index_slice.loc[indicator]
idxs = [i for i, isin in zip(idxs, indicator) if isin]
frames = index_slice.index.get_level_values("frame_number").tolist()
if seq_name is None:
seq_name_list = index_slice.index.get_level_values("sequence_name").tolist()
seq_name_set = set(seq_name_list)
if len(seq_name_set) > 1:
raise ValueError("Given indices belong to more than one sequence.")
elif len(seq_name_set) == 1:
seq_name = seq_name_list[0]
coalesced_ts = sa.sql.functions.coalesce(
self.frame_annotations_type.frame_timestamp, 0
)
stmt = sa.select(
coalesced_ts.label("frame_timestamp"),
self.frame_annotations_type.frame_number,
).where(
self.frame_annotations_type.sequence_name == seq_name,
self.frame_annotations_type.frame_number.in_(frames),
)
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
if len(frame_no_ts) != len(index_slice):
raise ValueError(
"Not all indices are found in the database; "
"do they belong to more than one sequence?"
)
return frame_no_ts, idxs
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
def _get_temp_index_table_instance(self, table_name: str = "__index"):
CachedTable = self.frame_annotations_type.metadata.tables.get(table_name)
if CachedTable is not None: # table definition is not idempotent
return CachedTable
return sa.Table(
table_name,
self.frame_annotations_type.metadata,
sa.Column("sequence_name", sa.String, primary_key=True),
sa.Column("frame_number", sa.Integer, primary_key=True),
sa.Column("_image_path", sa.String),
sa.Column("subset", sa.String),
prefixes=["TEMP"], # NOTE SQLite specific!
)
def _seq_name_to_seed(seq_name) -> int:
"""Generates numbers in [0, 2 ** 28)"""
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
def _safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None
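Referring back to the SqlIndexDataset docstring above, a hedged usage sketch (file paths are hypothetical placeholders; construction mirrors the `expand_args_fields(...)` + kwargs pattern used by the dataset map provider in this commit):

from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(SqlIndexDataset)
dataset = SqlIndexDataset(
    sqlite_metadata_file="/data/co3d_sql/metadata.sqlite",  # hypothetical path
    subset_lists_file="/data/co3d_sql/set_lists/set_lists_manyview_dev_0.sqlite",
    subsets=["train"],
    dataset_root="/data/co3d_sql",
    # pick_frames_sql_clause="_mask_mass > 1000",  # optional raw WHERE clause; trusted input only
)

frame = dataset[0]  # ordinal indexing in [0, len(dataset))
seq_name = next(iter(dataset.sequence_names()))
ts, frame_no, idx = next(dataset.sequence_frames_in_order(seq_name))
same_frame = dataset[seq_name, frame_no]  # (sequence_name, frame_number) indexing
# Metadata-only access; per the docstring this requires box_crop=False in the frame data builder:
# meta_only = dataset.meta[idx]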

View File

@ -0,0 +1,424 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import List, Optional, Tuple, Type
import numpy as np
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
)
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,
run_auto_creation,
)
from .sql_dataset import SqlIndexDataset
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
# _NEED_CONTROL is a list of those elements of SqlIndexDataset which
# are not directly specified for it in the config but come from the
# DatasetMapProvider.
_NEED_CONTROL: Tuple[str, ...] = (
"path_manager",
"subsets",
"sqlite_metadata_file",
"subset_lists_file",
)
logger = logging.getLogger(__name__)
@registry.register
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
Generates the training, validation, and testing dataset objects for
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite database.
The dataset is organized in the filesystem as follows::
self.dataset_root
<possible/partition/0>
<sequence_name_0>
depth_masks
depths
images
masks
pointcloud.ply
<sequence_name_1>
depth_masks
depths
images
masks
pointcloud.ply
...
<sequence_name_N>
set_lists
<subset_base_name_0>.json
<subset_base_name_1>.json
...
<subset_base_name_2>.json
eval_batches
<eval_batches_base_name_0>.json
<eval_batches_base_name_1>.json
...
<eval_batches_base_name_M>.json
frame_annotations.jgz
sequence_annotations.jgz
<possible/partition/1>
...
<possible/partition/K>
set_lists
<subset_base_name_0>.sqlite
<subset_base_name_1>.sqlite
...
<subset_base_name_2>.sqlite
eval_batches
<eval_batches_base_name_0>.json
<eval_batches_base_name_1>.json
...
<eval_batches_base_name_M>.json
The dataset contains sequences named `<sequence_name_i>` that may be partitioned by
directories such as `<possible/partition/0>`, e.g. representing categories, but they
can also be stored in a flat structure. Each sequence folder contains the list of
sequence images, depth maps, foreground masks, and valid-depth masks
`images`, `depths`, `masks`, and `depth_masks`, respectively. Furthermore,
`set_lists/` directories (within partitions or at the dataset root) store JSON or SQLite files
`<subset_base_name_l>.<ext>`, each describing a certain sequence subset.
These subset path conventions are not hard-coded, and an arbitrary relative path can be
specified by setting `self.subset_lists_path` to the path relative to the
dataset root.
Each `<subset_base_name_l>.json` file contains the following dictionary::
{
"train": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"val": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"test": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
}
defining the list of frames (identified with their `sequence_name` and
`frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of
SQLite format, `<subset_base_name_l>.sqlite` contains a table with the header::
| sequence_name | frame_number | image_path | subset |
Note that `frame_number` can be obtained only from the metadata and
does not necessarily correspond to the numeric suffix of the corresponding image
file name (e.g. a file `<partition_0>/<sequence_name_0>/images/frame00005.jpg` can
have its frame number set to `20`, not 5).
Each `<eval_batches_base_name_M>.json` file contains a list of evaluation examples
in the following form::
[
[ # batch 1
(sequence_name: str, frame_number: int, image_path: str),
...
],
[ # batch 2
(sequence_name: str, frame_number: int, image_path: str),
...
],
]
Note that the evaluation examples always come from the `"test"` subset of the dataset
(test frames can repeat across batches). The batches can contain a single element,
which is typical in the case of regular radiance-field fitting.
Args:
subset_lists_path: The relative path to the dataset subset definition.
For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json".
By default (None), the dataset is not partitioned into subsets (in that case, setting
`ignore_subsets` will speed up construction).
dataset_root: The root folder of the dataset.
metadata_basename: name of the SQL metadata file in dataset_root;
not expected to be changed by users
test_on_train: Construct validation and test datasets from
the training subset; note that in practice, in this
case all subset dataset objects will be the same
only_test_set: Load only the test set. Incompatible with `test_on_train`.
ignore_subsets: Don't filter by subsets in the dataset; note that in this
case all subset datasets will be the same
eval_batch_num_training_frames: Add a certain number of training frames to each
eval batch. Useful for evaluating models that require
source views as input (e.g. NeRF-WCE / PixelNeRF).
dataset_args: Specifies additional arguments to the
SqlIndexDataset constructor call.
path_manager_factory: (Optional) An object that generates an instance of
PathManager that can translate provided file paths.
path_manager_factory_class_type: The class type of `path_manager_factory`.
"""
category: Optional[str] = None
subset_list_name: Optional[str] = None # TODO: docs
# OR
subset_lists_path: Optional[str] = None
eval_batches_path: Optional[str] = None
dataset_root: str = _CO3D_SQL_DATASET_ROOT
metadata_basename: str = "metadata.sqlite"
test_on_train: bool = False
only_test_set: bool = False
ignore_subsets: bool = False
train_subsets: Tuple[str, ...] = ("train",)
val_subsets: Tuple[str, ...] = ("val",)
test_subsets: Tuple[str, ...] = ("test",)
eval_batch_num_training_frames: int = 0
# this is a mould that is never constructed, used to build self._dataset_map values
dataset_class_type: str = "SqlIndexDataset"
dataset: SqlIndexDataset
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
super().__init__()
run_auto_creation(self)
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
if self.ignore_subsets and not self.only_test_set:
self.test_on_train = True # no point in loading same data 3 times
path_manager = self.path_manager_factory.get()
sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename)
sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file)
if not os.path.isfile(sqlite_metadata_file):
# The sqlite_metadata_file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
f"Looking for frame annotations in {sqlite_metadata_file}."
+ " Please specify a correct dataset_root folder."
+ " Note: By default the root folder is taken from the"
+ " CO3D_SQL_DATASET_ROOT environment variable."
)
if self.subset_lists_path and self.subset_list_name:
raise ValueError(
"subset_lists_path and subset_list_name cannot be both set"
)
subset_lists_file = self._get_lists_file("set_lists")
# setup the common dataset arguments
common_dataset_kwargs = {
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
"sqlite_metadata_file": sqlite_metadata_file,
"dataset_root": self.dataset_root,
"subset_lists_file": subset_lists_file,
"path_manager": path_manager,
}
if self.category:
logger.info(f"Forcing category filter in the datasets to {self.category}")
common_dataset_kwargs["pick_categories"] = self.category.split(",")
# get the used dataset type
dataset_type: Type[SqlIndexDataset] = registry.get(
SqlIndexDataset, self.dataset_class_type
)
expand_args_fields(dataset_type)
if subset_lists_file is not None and not os.path.isfile(subset_lists_file):
available_subsets = self._get_available_subsets(
OmegaConf.to_object(common_dataset_kwargs["pick_categories"])
)
msg = f"Cannot find subset list file {self.subset_lists_path}."
if available_subsets:
msg += f" Some of the available subsets: {str(available_subsets)}."
raise ValueError(msg)
train_dataset = None
val_dataset = None
if not self.only_test_set:
# load the training set
logger.debug("Constructing train dataset.")
train_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets)
)
logger.info(f"Train dataset: {str(train_dataset)}")
if self.test_on_train:
assert train_dataset is not None
val_dataset = test_dataset = train_dataset
else:
# load the val and test sets
if not self.only_test_set:
# NOTE: this is always loaded in JsonProviderV2
logger.debug("Extracting val dataset.")
val_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets)
)
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")
eval_batches_file = self._get_lists_file("eval_batches")
del common_dataset_kwargs["eval_batches_file"]
test_dataset = dataset_type(
**common_dataset_kwargs,
subsets=self._get_subsets(self.test_subsets, True),
eval_batches_file=eval_batches_file,
)
logger.info(f"Test dataset: {str(test_dataset)}")
if (
eval_batches_file is not None
and self.eval_batch_num_training_frames > 0
):
self._extend_eval_batches(test_dataset)
self._dataset_map = DatasetMap(
train=train_dataset, val=val_dataset, test=test_dataset
)
def _get_subsets(self, subsets, is_eval: bool = False):
if self.ignore_subsets:
return None
if is_eval and self.eval_batch_num_training_frames > 0:
# we will need to have training frames for extended batches
return list(subsets) + list(self.train_subsets)
return subsets
def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None:
rng = np.random.default_rng(seed=0)
eval_batches = test_dataset.get_eval_batches()
if eval_batches is None:
raise ValueError("Eval batches were not loaded!")
for batch in eval_batches:
sequence = batch[0][0]
seq_frames = list(
test_dataset.sequence_frames_in_order(sequence, self.train_subsets)
)
idx_to_add = rng.permutation(len(seq_frames))[
: self.eval_batch_num_training_frames
]
batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add)
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
"""
Called by get_default_args.
Certain fields are not exposed on each dataset class
but rather are controlled by this provider class.
"""
for key in _NEED_CONTROL:
del args[key]
def create_dataset(self):
# No `dataset` member of this class is created.
# The dataset(s) live in `self.get_dataset_map`.
pass
def get_dataset_map(self) -> DatasetMap:
return self._dataset_map # pyre-ignore [16]
def _get_available_subsets(self, categories: List[str]):
"""
Get the available subset names for a given category folder (if given) inside
a root dataset folder `dataset_root`.
"""
path_manager = self.path_manager_factory.get()
subsets: List[str] = []
for prefix in [""] + categories:
set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists")
if not (
(path_manager is not None) and path_manager.isdir(set_list_dir)
) and not os.path.isdir(set_list_dir):
continue
set_list_files = (os.listdir if path_manager is None else path_manager.ls)(
set_list_dir
)
subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files)
return subsets
def _get_lists_file(self, flavor: str) -> Optional[str]:
if flavor == "eval_batches":
subset_lists_path = self.eval_batches_path
else:
subset_lists_path = self.subset_lists_path
if not subset_lists_path and not self.subset_list_name:
return None
category_elem = ""
if self.category and "," not in self.category:
# if multiple categories are given, looking for global set lists
category_elem = self.category
subset_lists_path = subset_lists_path or (
os.path.join(
category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}"
)
)
assert subset_lists_path
path_manager = self.path_manager_factory.get()
# try absolute path first
subset_lists_file = _get_local_path_check_extensions(
subset_lists_path, path_manager
)
if subset_lists_file:
return subset_lists_file
full_path = os.path.join(self.dataset_root, subset_lists_path)
subset_lists_file = _get_local_path_check_extensions(full_path, path_manager)
if not subset_lists_file:
raise FileNotFoundError(
f"Subset lists path given but not found: {full_path}"
)
return subset_lists_file
def _get_local_path_check_extensions(
path, path_manager, extensions=("", ".sqlite", ".json")
) -> Optional[str]:
for ext in extensions:
local = _local_path(path_manager, path + ext)
if os.path.isfile(local):
return local
return None
def _local_path(path_manager, path: str) -> str:
if path_manager is None:
return path
return path_manager.get_local_path(path)
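The SQLite flavour of the subset lists described above is a table named `set_lists` (see `_SET_LISTS_TABLE` in sql_dataset.py) with the four columns shown. A hedged sketch of producing such a file with pandas (rows and the output path are made up):

import pandas as pd
import sqlalchemy as sa

rows = [
    # sequence_name, frame_number, image_path, subset
    ("seq_000", 0, "seq_000/images/frame000000.jpg", "train"),
    ("seq_000", 5, "seq_000/images/frame000005.jpg", "val"),
    ("seq_000", 7, "seq_000/images/frame000007.jpg", "test"),
]
df = pd.DataFrame(rows, columns=["sequence_name", "frame_number", "image_path", "subset"])
engine = sa.create_engine("sqlite:///set_lists_example.sqlite")  # hypothetical output file
df.to_sql("set_lists", engine, index=False, if_exists="replace")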

View File

@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DataLoaderMap,
SceneBatchSampler,
SequenceDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D
# and support both eval_batches protocols
@registry.register
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
"""
Implementation of DataLoaderMapProviderBase that may use internal eval batches for
the test dataset. In particular, if `eval_batches_relpath` is set, it loads
eval batches from that json file, otherwise test set is treated in the same way as
train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type`
are respected.
If conditioning is not required, then the batch size should
be set as 1, and most of the fields do not matter.
If conditioning is required, each batch will contain one main
frame first (the one to predict), and the rest of the elements are for
conditioning.
If images_per_seq_options is left empty, the conditioning
frames are picked according to the conditioning type given.
This takes no account of the order of frames in a
scene, or of which frames belong to which scene.
If images_per_seq_options is given, then the conditioning types
must be SAME and the remaining fields are used.
Members:
batch_size: The size of the batch of the data loader.
num_workers: Number of data-loading threads in each data loader.
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
an epoch is the length of the training set.
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
an epoch is the length of the validation set.
dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of
batches in a testing epoch. Or 0 to mean an epoch is the length of the test
set.
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
value. Empty (the default) means that we are not careful about which frames
come from which scene.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
"""
batch_size: int = 1
num_workers: int = 0
dataset_length_train: int = 0
dataset_length_val: int = 0
dataset_length_test: int = 0
images_per_seq_options: Tuple[int, ...] = ()
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
def __post_init__(self):
run_auto_creation(self)
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
"""
Returns a collection of data loaders for a given collection of datasets.
"""
train = self._make_generic_data_loader(
datasets.train,
self.dataset_length_train,
datasets.train,
)
val = self._make_generic_data_loader(
datasets.val,
self.dataset_length_val,
datasets.train,
)
if datasets.test is not None and datasets.test.get_eval_batches() is not None:
test = self._make_eval_data_loader(datasets.test)
else:
test = self._make_generic_data_loader(
datasets.test,
self.dataset_length_test,
datasets.train,
)
return DataLoaderMap(train=train, val=val, test=test)
def _make_eval_data_loader(
self,
dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
if dataset is None:
return None
return DataLoader(
dataset,
batch_sampler=dataset.get_eval_batches(),
**self._get_data_loader_common_kwargs(dataset),
)
def _make_generic_data_loader(
self,
dataset: Optional[DatasetBase],
num_batches: int,
train_dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
"""
Returns the dataloader for a dataset.
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
train_dataset: the training dataset, used if conditioning_type==TRAIN
conditioning_type: source for padding of batches
"""
if dataset is None:
return None
data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)
if len(self.images_per_seq_options) > 0:
# this is a typical few-view setup
# conditioning comes from the same subset since subsets are split by seqs
batch_sampler = SceneBatchSampler(
dataset,
self.batch_size,
num_batches=len(dataset) if num_batches <= 0 else num_batches,
images_per_seq_options=self.images_per_seq_options,
sample_consecutive_frames=self.sample_consecutive_frames,
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
if self.batch_size == 1:
# this is a typical many-view setup (without conditioning)
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
# edge case: conditioning on train subset, typical for Nerformer-like many-view
# there is only one sequence in all datasets, so we condition on another subset
return self._train_loader(
dataset, train_dataset, num_batches, data_loader_kwargs
)
def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
return {
"num_workers": self.num_workers,
"collate_fn": dataset.frame_data_type.collate,
}
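A hedged configuration sketch for this provider (values are illustrative; the nested arg names follow the `<member>_<ClassType>_args` convention visible in the tests below). With a non-empty `images_per_seq_options`, the generic loaders use SceneBatchSampler, and when the test dataset carries eval batches they are used as the batch sampler:

from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(ImplicitronDataSource)
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
loader_args = args.data_loader_map_provider_TrainEvalDataLoaderMapProvider_args
loader_args.batch_size = 10
loader_args.num_workers = 4
loader_args.images_per_seq_options = [10]  # sample 10 frames per sequence in each batch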

View File

@@ -164,6 +164,7 @@ setup(
             "tqdm>4.29.0",
             "matplotlib",
             "accelerate",
+            "sqlalchemy>=2.0",
         ],
     },
     entry_points={

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
SequenceDataLoaderMapProvider,
SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
SqlIndexDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
TrainEvalDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
class TestCo3dSqlDataSource(unittest.TestCase):
def test_no_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.ignore_subsets = True
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.pick_categories = ["skateboard"]
dataset_args.limit_sequences_to = 1
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 202)
for frame in data_loaders.train:
self.assertIsNone(frame.frame_type)
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = (
"skateboard/set_lists/set_lists_manyview_dev_0.json"
)
# this will naturally limit to one sequence (no need to limit by cat/sequence)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
# check loading blobs
self.assertEqual(frame.image_rgb.shape[-1], 800)
break
def test_sql_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(
frame.image_rgb.shape[-1], 800
) # check loading blobs
break
@unittest.skip("It takes 75 seconds; skipping by default")
def test_huge_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 3158974)
self.assertEqual(len(data_loaders.val), 518417)
self.assertEqual(len(data_loaders.test), 518417)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_broken_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "et_non_est"
provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
with self.assertRaises(FileNotFoundError) as err:
ImplicitronDataSource(**args)
# check the hint text
self.assertIn("Subset lists path given but not found", str(err.exception))
def test_eval_batches(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
provider_args.eval_batches_path = (
"skateboard/eval_batches/eval_batches_manyview_dev_0.json"
)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_eval_batches_from_subset_list_name(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_list_name = "manyview_dev_0"
provider_args.category = "skateboard"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
dataset, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_frame_access(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
frame_builder_args.load_point_clouds = True
frame_builder_args.box_crop = False # required for .meta
data_source = ImplicitronDataSource(**args)
dataset_map, _ = data_source.get_datasets_and_dataloaders()
dataset = dataset_map["train"]
for idx in [10, ("245_26182_52130", 22)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertIsNone(example_meta.image_rgb)
self.assertIsNone(example_meta.fg_probability)
self.assertIsNone(example_meta.depth_map)
self.assertIsNone(example_meta.sequence_point_cloud)
self.assertIsNotNone(example_meta.camera)
self.assertIsNotNone(example.image_rgb)
self.assertIsNotNone(example.fg_probability)
self.assertIsNotNone(example.depth_map)
self.assertIsNotNone(example.sequence_point_cloud)
self.assertIsNotNone(example.camera)
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)

View File

@ -0,0 +1,522 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
from collections import Counter
import pkg_resources
import torch
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
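# Frame-data-builder overrides that skip loading image/depth/mask blobs and
# disable box cropping, so the tests below only exercise the SQL metadata index.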
NO_BLOBS_KWARGS = {
"dataset_root": "",
"load_images": False,
"load_depths": False,
"load_masks": False,
"load_depth_masks": False,
"box_crop": False,
}
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
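# The fixture metadata covers 100 frames: 2 categories ("cat0", "cat1") with
# 5 sequences each and 10 frames per sequence; the set lists split every
# sequence evenly into "train" and "test" frames (as the tests below exercise).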
class TestSqlDataset(unittest.TestCase):
def test_basic(self, sequence="cat1_seq2", frame_number=4):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
# test indexing
with self.assertRaises(IndexError):
dataset[len(dataset) + 1]
# test sequence-frame indexing
item = dataset[sequence, frame_number]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, frame_number)
with self.assertRaises(IndexError):
dataset[sequence, 13]
def test_filter_empty_masks(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
def test_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
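# the explicit SQL clause below should reproduce the remove_empty_masks filter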
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
def test_limit_categories(self, category="cat0"):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_categories=[category],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for i in range(len(dataset)):
self.assertEqual(dataset[i].sequence_category, category)
def test_limit_sequences(self, num_sequences=3):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_sequences_to=num_sequences,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10 * num_sequences)
def delist(sequence_name):
return sequence_name if isinstance(sequence_name, str) else sequence_name[0]
unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), num_sequences)
def test_pick_exclude_sequences(self, sequence="cat1_seq2"):
# pick sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertCountEqual(unique_seqs, {sequence})
item = dataset[sequence, 0]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, 0)
# exclude sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
exclude_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 90)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertNotIn(sequence, unique_seqs)
with self.assertRaises(IndexError):
dataset[sequence, 0]
def test_limit_frames(self, num_frames=13):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), 2)
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=1000,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_limit_frames_per_sequence(self, num_frames=2):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames * 10)
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertEqual(len(seq_counts), 10)
self.assertCountEqual(
set(seq_counts.values()), {2}
) # all counts are num_frames
with self.assertRaises(IndexError):
dataset[next(iter(seq_counts)), num_frames + 1]
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=13,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_filter_medley(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=14, # retaining full "cat1_seq1" and 4 from "cat1_seq2"
n_frames_per_sequence=6, # cutting "cat1_seq1" to 6 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 6)
self.assertEqual(seq_counts["cat1_seq2"], 4)
def test_subsets_trivial(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
limit_to=100, # force sorting
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
def test_subsets_filter_empty_masks(self):
# we need to test this case as it uses quite different logic with `df.drop()`
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
def test_subsets_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
def test_single_subset(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
with self.assertRaises(IndexError):
dataset[51]
# check the items are ordered: the train subset keeps every other frame,
# so frame numbers step by 2
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number < 2:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 2)
last_frame_number = item.frame_number
item = dataset[last_sequence, 0]
self.assertEqual(item.sequence_name, last_sequence)
with self.assertRaises(IndexError):
dataset[last_sequence, 1]
def test_subset_with_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=7, # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2"
n_frames_per_sequence=3, # cutting "cat1_seq1" to 3 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 3 frames from cat1_seq1 and 2 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 3)
self.assertEqual(seq_counts["cat1_seq2"], 2)
def test_visitor(self):
dataset_sorted = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset_sorted.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
# test legacy visitor
old_indices = None
for seq in sequences:
last_ts = float("-Inf")
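# on the sorted index, get_loc returns a contiguous slice covering this
# sequence's rows, which we expand into explicit dataset indices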
rows = dataset_sorted._index.index.get_loc(seq)
indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
self.assertEqual(len(fn_ts_list), len(indices))
if old_indices:
# check raising if we ask for multiple sequences
with self.assertRaises(ValueError):
dataset_sorted.get_frame_numbers_and_timestamps(
indices + old_indices
)
old_indices = indices
def test_visitor_subsets(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=100, # force sorting
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
seq_frames = list(dataset.sequence_frames_in_order(seq))
self.assertEqual(len(seq_frames), 10)
for ts, _, idx in seq_frames:
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
last_ts = float("-Inf")
train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
self.assertEqual(len(train_frames), 5)
for ts, _, _ in train_frames:
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
def test_category_to_sequence_names(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 5)
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
def test_category_to_sequence_names_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
exclude_sequences=["cat1_seq0"],
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 4) # minus one
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
def test_meta_access(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for idx in [10, ("cat0_seq2", 2)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)
def test_meta_access_no_blobs(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args={
"dataset_root": ".",
"box_crop": False, # required by blob-less accessor
},
)
self.assertIsNone(dataset.meta[0].image_rgb)
self.assertIsNone(dataset.meta[0].fg_probability)
self.assertIsNone(dataset.meta[0].depth_map)
self.assertIsNone(dataset.meta[0].sequence_point_cloud)
self.assertIsNotNone(dataset.meta[0].camera)