From 32e1992924929a9b79e880ed6f5bdc74089e8c73 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Tue, 25 Apr 2023 09:56:15 -0700 Subject: [PATCH] SQL Index Dataset Summary: Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay. It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS. Reviewed By: bottler Differential Revision: D45086611 fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef --- pytorch3d/implicitron/dataset/frame_data.py | 36 +- pytorch3d/implicitron/dataset/orm_types.py | 161 ++++ pytorch3d/implicitron/dataset/sql_dataset.py | 735 ++++++++++++++++++ .../dataset/sql_dataset_provider.py | 424 ++++++++++ .../train_eval_data_loader_provider.py | 189 +++++ setup.py | 1 + .../data/sql_dataset/set_lists_100.json | 1 + .../data/sql_dataset/sql_dataset_100.sqlite | Bin 0 -> 81920 bytes tests/implicitron/test_co3d_sql.py | 246 ++++++ tests/implicitron/test_sql_dataset.py | 522 +++++++++++++ 10 files changed, 2309 insertions(+), 6 deletions(-) create mode 100644 pytorch3d/implicitron/dataset/orm_types.py create mode 100644 pytorch3d/implicitron/dataset/sql_dataset.py create mode 100644 pytorch3d/implicitron/dataset/sql_dataset_provider.py create mode 100644 pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py create mode 100644 tests/implicitron/data/sql_dataset/set_lists_100.json create mode 100644 tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite create mode 100644 tests/implicitron/test_co3d_sql.py create mode 100644 tests/implicitron/test_sql_dataset.py diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index e8e88b70..88455319 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -450,6 +450,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC): self, frame_annotation: types.FrameAnnotation, sequence_annotation: types.SequenceAnnotation, + load_blobs: bool = True, ) -> FrameDataSubtype: """An abstract method to build the frame data based on raw frame/sequence annotations, load the binary data and adjust them according to the metadata. @@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): Beware that modifications of frame data are done in-place. Args: - dataset_root: The root folder of the dataset; all the paths in jsons are - specified relative to this root (but not json paths themselves). + dataset_root: The root folder of the dataset; all paths in frame / sequence + annotations are defined w.r.t. this root. Has to be set if any of the + load_* flabs below is true. load_images: Enable loading the frame RGB data. load_depths: Enable loading the frame depth maps. load_depth_masks: Enable loading the frame depth map masks denoting the @@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): path_manager: Optionally a PathManager for interpreting paths in a special way. """ - dataset_root: str = "" + dataset_root: Optional[str] = None load_images: bool = True load_depths: bool = True load_depth_masks: bool = True @@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): box_crop_context: float = 0.3 path_manager: Any = None + def __post_init__(self) -> None: + load_any_blob = ( + self.load_images + or self.load_depths + or self.load_depth_masks + or self.load_masks + or self.load_point_clouds + ) + if load_any_blob and self.dataset_root is None: + raise ValueError( + "dataset_root must be set to load any blob data. " + "Make sure it is set in either FrameDataBuilder or Dataset params." + ) + + if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore + raise ValueError( + f"dataset_root is passed but {self.dataset_root} does not exist." + ) + def build( self, frame_annotation: types.FrameAnnotation, @@ -567,7 +588,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): if bbox_xywh is None and fg_mask_np is not None: bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) - frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long) + frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float) if frame_annotation.image is not None: image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) @@ -612,7 +633,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def _load_fg_probability( self, entry: types.FrameAnnotation ) -> Tuple[np.ndarray, str]: - full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore + assert self.dataset_root is not None and entry.mask is not None + full_path = os.path.join(self.dataset_root, entry.mask.path) fg_probability = load_mask(self._local_path(full_path)) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( @@ -647,7 +669,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): fg_probability: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, str, torch.Tensor]: entry_depth = entry.depth - assert entry_depth is not None + assert self.dataset_root is not None and entry_depth is not None path = os.path.join(self.dataset_root, entry_depth.path) depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment) @@ -657,6 +679,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): if self.load_depth_masks: assert entry_depth.mask_path is not None + # pyre-ignore mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) depth_mask = load_depth_mask(self._local_path(mask_path)) else: @@ -705,6 +728,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): ) if path.startswith(unwanted_prefix): path = path[len(unwanted_prefix) :] + assert self.dataset_root is not None return os.path.join(self.dataset_root, path) def _local_path(self, path: str) -> str: diff --git a/pytorch3d/implicitron/dataset/orm_types.py b/pytorch3d/implicitron/dataset/orm_types.py new file mode 100644 index 00000000..5736ab4b --- /dev/null +++ b/pytorch3d/implicitron/dataset/orm_types.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This functionality requires SQLAlchemy 2.0 or later. + +import math +import struct +from typing import Optional, Tuple + +import numpy as np + +from pytorch3d.implicitron.dataset.types import ( + DepthAnnotation, + ImageAnnotation, + MaskAnnotation, + PointCloudAnnotation, + VideoAnnotation, + ViewpointAnnotation, +) + +from sqlalchemy import LargeBinary +from sqlalchemy.orm import ( + composite, + DeclarativeBase, + Mapped, + mapped_column, + MappedAsDataclass, +) +from sqlalchemy.types import TypeDecorator + + +# these produce policies to serialize structured types to blobs +def ArrayTypeFactory(shape): + class NumpyArrayType(TypeDecorator): + impl = LargeBinary + + def process_bind_param(self, value, dialect): + if value is not None: + if value.shape != shape: + raise ValueError(f"Passed an array of wrong shape: {value.shape}") + return value.astype(np.float32).tobytes() + return None + + def process_result_value(self, value, dialect): + if value is not None: + return np.frombuffer(value, dtype=np.float32).reshape(shape) + return None + + return NumpyArrayType + + +def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)): + format_symbol = { + float: "f", # float32 + int: "i", # int32 + }[dtype] + + class TupleType(TypeDecorator): + impl = LargeBinary + _format = format_symbol * math.prod(shape) + + def process_bind_param(self, value, _): + if value is None: + return None + + if len(shape) > 1: + value = np.array(value, dtype=dtype).reshape(-1) + + return struct.pack(TupleType._format, *value) + + def process_result_value(self, value, _): + if value is None: + return None + + loaded = struct.unpack(TupleType._format, value) + if len(shape) > 1: + loaded = _rec_totuple( + np.array(loaded, dtype=dtype).reshape(shape).tolist() + ) + + return loaded + + return TupleType + + +def _rec_totuple(t): + if isinstance(t, list): + return tuple(_rec_totuple(x) for x in t) + + return t + + +class Base(MappedAsDataclass, DeclarativeBase): + """subclasses will be converted to dataclasses""" + + +class SqlFrameAnnotation(Base): + __tablename__ = "frame_annots" + + sequence_name: Mapped[str] = mapped_column(primary_key=True) + frame_number: Mapped[int] = mapped_column(primary_key=True) + frame_timestamp: Mapped[float] = mapped_column(index=True) + + image: Mapped[ImageAnnotation] = composite( + mapped_column("_image_path"), + mapped_column("_image_size", TupleTypeFactory(int)), + ) + + depth: Mapped[DepthAnnotation] = composite( + mapped_column("_depth_path", nullable=True), + mapped_column("_depth_scale_adjustment", nullable=True), + mapped_column("_depth_mask_path", nullable=True), + ) + + mask: Mapped[MaskAnnotation] = composite( + mapped_column("_mask_path", nullable=True), + mapped_column("_mask_mass", index=True, nullable=True), + mapped_column( + "_mask_bounding_box_xywh", + TupleTypeFactory(float, shape=(4,)), + nullable=True, + ), + ) + + viewpoint: Mapped[ViewpointAnnotation] = composite( + mapped_column( + "_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True + ), + mapped_column( + "_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True + ), + mapped_column( + "_viewpoint_focal_length", TupleTypeFactory(float), nullable=True + ), + mapped_column( + "_viewpoint_principal_point", TupleTypeFactory(float), nullable=True + ), + mapped_column("_viewpoint_intrinsics_format", nullable=True), + ) + + +class SqlSequenceAnnotation(Base): + __tablename__ = "sequence_annots" + + sequence_name: Mapped[str] = mapped_column(primary_key=True) + category: Mapped[str] = mapped_column(index=True) + + video: Mapped[VideoAnnotation] = composite( + mapped_column("_video_path", nullable=True), + mapped_column("_video_length", nullable=True), + ) + point_cloud: Mapped[PointCloudAnnotation] = composite( + mapped_column("_point_cloud_path", nullable=True), + mapped_column("_point_cloud_quality_score", nullable=True), + mapped_column("_point_cloud_n_points", nullable=True), + ) + # the bigger the better + viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None) diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py new file mode 100644 index 00000000..605e8b52 --- /dev/null +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -0,0 +1,735 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib +import json +import logging +import os +from dataclasses import dataclass +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import numpy as np +import pandas as pd +import sqlalchemy as sa +import torch +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase + +from pytorch3d.implicitron.dataset.frame_data import ( # noqa + FrameData, + FrameDataBuilder, + FrameDataBuilderBase, +) +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from sqlalchemy.orm import Session + +from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation + + +logger = logging.getLogger(__name__) + + +_SET_LISTS_TABLE: str = "set_lists" + + +@registry.register +class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore + """ + A dataset with annotations stored as SQLite tables. This is an index-based dataset. + The length is returned after all sequence and frame filters are applied (see param + definitions below). Indices can either be ordinal in [0, len), or pairs of + (sequence_name, frame_number); with the performance of `dataset[i]` and + `dataset[sequence_name, frame_number]` being same. A faster way to get metadata only + (without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False. + With ordinal indexing, the sequences are NOT guaranteed to span contiguous index + ranges, and frame numbers are NOT guaranteed to be increasing within a sequence. + Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order` + iterators, which are efficient. + + This functionality requires SQLAlchemy 2.0 or later. + + Metadata-related args: + sqlite_metadata_file: A SQLite file containing frame and sequence annotation + tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation, + respectively). + dataset_root: A root directory to look for images, masks, etc. It can be + alternatively set in `frame_data_builder` args, but this takes precedence. + subset_lists_file: A JSON/sqlite file containing the lists of frames + corresponding to different subsets (e.g. train/val/test) of the dataset; + format: {subset: [(sequence_name, frame_id, file_path)]}. All entries + must be present in frame_annotation metadata table. + path_manager: a facade for non-POSIX filesystems. + subsets: Restrict frames/sequences only to the given list of subsets + as defined in subset_lists_file (see above). Applied before all other + filters. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask (needs frame_annotation.mask.mass to be set; + null values are retained). + pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations + NOTE: This is a potential security risk! The string is passed to the SQL + engine verbatim. Don’t expose it to end users of your application! + pick_categories: Restrict the dataset to the given list of categories. + pick_sequences: A Sequence of sequence names to restrict the dataset to. + exclude_sequences: A Sequence of the names of the sequences to exclude. + limit_sequences_to: Limit the dataset to the first `limit_sequences_to` + sequences (after other sequence filters have been applied but before + frame-based filters). + limit_to: Limit the dataset to the first #limit_to frames (after other + filters have been applied, except n_frames_per_sequence). + n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence` + frames in each sequences uniformly without replacement if it has + more frames than that; applied after other frame-level filters. + seed: The seed of the random generator sampling `n_frames_per_sequence` + random frames per sequence. + """ + + frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation + + sqlite_metadata_file: str = "" + dataset_root: Optional[str] = None + subset_lists_file: str = "" + eval_batches_file: Optional[str] = None + path_manager: Any = None + subsets: Optional[List[str]] = None + remove_empty_masks: bool = True + pick_frames_sql_clause: Optional[str] = None + pick_categories: Tuple[str, ...] = () + + pick_sequences: Tuple[str, ...] = () + exclude_sequences: Tuple[str, ...] = () + limit_sequences_to: int = 0 + limit_to: int = 0 + n_frames_per_sequence: int = -1 + seed: int = 0 + remove_empty_masks_poll_whole_table_threshold: int = 300_000 + # we set it manually in the constructor + # _index: pd.DataFrame = field(init=False) + + frame_data_builder: FrameDataBuilderBase + frame_data_builder_class_type: str = "FrameDataBuilder" + + def __post_init__(self) -> None: + if sa.__version__ < "2.0": + raise ImportError("This class requires SQL Alchemy 2.0 or later") + + if not self.sqlite_metadata_file: + raise ValueError("sqlite_metadata_file must be set") + + if self.dataset_root: + frame_builder_type = self.frame_data_builder_class_type + getattr(self, f"frame_data_builder_{frame_builder_type}_args")[ + "dataset_root" + ] = self.dataset_root + + run_auto_creation(self) + + # pyre-ignore + self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}") + + sequences = self._get_filtered_sequences_if_any() + + if self.subsets: + index = self._build_index_from_subset_lists(sequences) + else: + # TODO: if self.subset_lists_file and not self.subsets, it might be faster to + # still use the concatenated lists, assuming they cover the whole dataset + index = self._build_index_from_db(sequences) + + if self.n_frames_per_sequence >= 0: + index = self._stratified_sample_index(index) + + if len(index) == 0: + raise ValueError(f"There are no frames in the subsets: {self.subsets}!") + + self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore + + self.eval_batches = None # pyre-ignore + if self.eval_batches_file: + self.eval_batches = self._load_filter_eval_batches() + + logger.info(str(self)) + + def __len__(self) -> int: + # pyre-ignore[16] + return len(self._index) + + def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData: + """ + Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair + """ + return self._get_item(frame_idx, True) + + @property + def meta(self): + """ + Allows accessing metadata only without loading blobs using `dataset.meta[idx]`. + Requires box_crop==False, since in that case, cameras cannot be adjusted + without loading masks. + + Returns: + FrameData objects with blob fields like `image_rgb` set to None. + + Raises: + ValueError if dataset.box_crop is set. + """ + return SqlIndexDataset._MetadataAccessor(self) + + @dataclass + class _MetadataAccessor: + dataset: "SqlIndexDataset" + + def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData: + return self.dataset._get_item(frame_idx, False) + + def _get_item( + self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True + ) -> FrameData: + if isinstance(frame_idx, int): + if frame_idx >= len(self._index): + raise IndexError(f"index {frame_idx} out of range {len(self._index)}") + + seq, frame = self._index.index[frame_idx] + else: + seq, frame, *rest = frame_idx + if (seq, frame) not in self._index.index: + raise IndexError( + f"Sequence-frame index {frame_idx} not found; was it filtered out?" + ) + + if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]: + raise IndexError(f"Non-matching image path in {frame_idx}.") + + stmt = sa.select(self.frame_annotations_type).where( + self.frame_annotations_type.sequence_name == seq, + self.frame_annotations_type.frame_number + == int(frame), # cast from np.int64 + ) + seq_stmt = sa.select(SqlSequenceAnnotation).where( + SqlSequenceAnnotation.sequence_name == seq + ) + with Session(self._sql_engine) as session: + entry = session.scalars(stmt).one() + seq_metadata = session.scalars(seq_stmt).one() + + assert entry.image.path == self._index.loc[(seq, frame), "_image_path"] + + frame_data = self.frame_data_builder.build( + entry, seq_metadata, load_blobs=load_blobs + ) + + # The rest of the fields are optional + frame_data.frame_type = self._get_frame_type(entry) + return frame_data + + def __str__(self) -> str: + # pyre-ignore[16] + return f"SqlIndexDataset #frames={len(self._index)}" + + def sequence_names(self) -> Iterable[str]: + """Returns an iterator over sequence names in the dataset.""" + return self._index.index.unique("sequence_name") + + # override + def category_to_sequence_names(self) -> Dict[str, List[str]]: + stmt = sa.select( + SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name + ).where( # we limit results to sequences that have frames after all filters + SqlSequenceAnnotation.sequence_name.in_(self.sequence_names()) + ) + with self._sql_engine.connect() as connection: + cat_to_seqs = pd.read_sql(stmt, connection) + + return cat_to_seqs.groupby("category")["sequence_name"].apply(list).to_dict() + + # override + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None + ) -> List[Tuple[int, float]]: + """ + Implements the DatasetBase method. + + NOTE: Avoid this function as there are more efficient alternatives such as + querying `dataset[idx]` directly or getting all sequence frames with + `sequence_[frames|indices]_in_order`. + + Return the index and timestamp in their videos of the frames whose + indices are given in `idxs`. They need to belong to the same sequence! + If timestamps are absent, they are replaced with zeros. + This is used for letting SceneBatchSampler identify consecutive + frames. + + Args: + idxs: a sequence int frame index in the dataset (it can be a slice) + subset_filter: must remain None + + Returns: + list of tuples of + - frame index in video + - timestamp of frame in video, coalesced with 0s + + Raises: + ValueError if idxs belong to more than one sequence. + """ + + if subset_filter is not None: + raise NotImplementedError( + "Subset filters are not supported in SQL Dataset. " + "We encourage creating a dataset per subset." + ) + + index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs) + # alternatively, we can use `.values.tolist()`, which may be faster + # but returns a list of lists + return list(index_slice.itertuples()) + + # override + def sequence_frames_in_order( + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None + ) -> Iterator[Tuple[float, int, int]]: + """ + Overrides the default DatasetBase implementation (we don’t use `_seq_to_idx`). + Returns an iterator over the frame indices in a given sequence. + We attempt to first sort by timestamp (if they are available), + then by frame number. + + Args: + seq_name: the name of the sequence. + subset_filter: subset names to filter to + + Returns: + an iterator over triplets `(timestamp, frame_no, dataset_idx)`, + where `frame_no` is the index within the sequence, and + `dataset_idx` is the index within the dataset. + `None` timestamps are replaced with 0s. + """ + # TODO: implement sort_timestamp_first? (which would matter if the orders + # of frame numbers and timestamps are different) + rows = self._index.index.get_loc(seq_name) + if isinstance(rows, slice): + assert rows.stop is not None, "Unexpected result from pandas" + rows = range(rows.start or 0, rows.stop, rows.step or 1) + else: + rows = np.where(rows)[0] + + index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices( + rows, seq_name, subset_filter + ) + index_slice["idx"] = idx + + yield from index_slice.itertuples(index=False) + + # override + def get_eval_batches(self) -> Optional[List[Any]]: + """ + This class does not support eval batches with ordinal indices. You can pass + eval_batches as a batch_sampler to a data_loader since the dataset supports + `dataset[seq_name, frame_no]` indexing. + """ + return self.eval_batches + + # override + def join(self, other_datasets: Iterable[DatasetBase]) -> None: + raise ValueError("Not supported! Preprocess the data by merging them instead.") + + # override + @property + def frame_data_type(self) -> Type[FrameData]: + return self.frame_data_builder.frame_data_type + + def is_filtered(self) -> bool: + """ + Returns `True` in case the dataset has been filtered and thus some frame + annotations stored on the disk might be missing in the dataset object. + Does not account for subsets. + + Returns: + is_filtered: `True` if the dataset has been filtered, else `False`. + """ + return ( + self.remove_empty_masks + or self.limit_to > 0 + or self.limit_sequences_to > 0 + or len(self.pick_sequences) > 0 + or len(self.exclude_sequences) > 0 + or len(self.pick_categories) > 0 + or self.n_frames_per_sequence > 0 + ) + + def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]: + # maximum possible query: WHERE category IN 'self.pick_categories' + # AND sequence_name IN 'self.pick_sequences' + # AND sequence_name NOT IN 'self.exclude_sequences' + # LIMIT 'self.limit_sequence_to' + + stmt = sa.select(SqlSequenceAnnotation.sequence_name) + + where_conditions = [ + *self._get_category_filters(), + *self._get_pick_filters(), + *self._get_exclude_filters(), + ] + if where_conditions: + stmt = stmt.where(*where_conditions) + + if self.limit_sequences_to > 0: + logger.info( + f"Limiting dataset to first {self.limit_sequences_to} sequences" + ) + # NOTE: ROWID is SQLite-specific + stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to) + + if not where_conditions and self.limit_sequences_to <= 0: + # we will not need to filter by sequences + return None + + with self._sql_engine.connect() as connection: + sequences = pd.read_sql_query(stmt, connection)["sequence_name"] + logger.info("... retained %d sequences" % len(sequences)) + + return sequences + + def _get_category_filters(self) -> List[sa.ColumnElement]: + if not self.pick_categories: + return [] + + logger.info(f"Limiting dataset to categories: {self.pick_categories}") + return [SqlSequenceAnnotation.category.in_(self.pick_categories)] + + def _get_pick_filters(self) -> List[sa.ColumnElement]: + if not self.pick_sequences: + return [] + + logger.info(f"Limiting dataset to sequences: {self.pick_sequences}") + return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)] + + def _get_exclude_filters(self) -> List[sa.ColumnOperators]: + if not self.exclude_sequences: + return [] + + logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}") + return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)] + + def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame: + assert self.subsets is not None + with open(subset_lists_path, "r") as f: + subset_to_seq_frame = json.load(f) + + seq_frame_list = sum( + ( + [(*row, subset) for row in subset_to_seq_frame[subset]] + for subset in self.subsets + ), + [], + ) + index = pd.DataFrame( + seq_frame_list, + columns=["sequence_name", "frame_number", "_image_path", "subset"], + ) + return index + + def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame: + subsets = self.subsets + assert subsets is not None + # we need a new engine since we store the subsets in a separate DB + engine = sa.create_engine(f"sqlite:///{subset_lists_path}") + table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine) + stmt = sa.select(table).where(table.c.subset.in_(subsets)) + with engine.connect() as connection: + index = pd.read_sql(stmt, connection) + + return index + + def _build_index_from_subset_lists( + self, sequences: Optional[pd.Series] + ) -> pd.DataFrame: + if not self.subset_lists_file: + raise ValueError("Requested subsets but subset_lists_file not given") + + logger.info(f"Loading subset lists from {self.subset_lists_file}.") + + subset_lists_path = self._local_path(self.subset_lists_file) + if subset_lists_path.lower().endswith(".json"): + index = self._load_subsets_from_json(subset_lists_path) + else: + index = self._load_subsets_from_sql(subset_lists_path) + index = index.set_index(["sequence_name", "frame_number"]) + logger.info(f" -> loaded {len(index)} samples of {self.subsets}.") + + if sequences is not None: + logger.info("Applying filtered sequences.") + sequence_values = index.index.get_level_values("sequence_name") + index = index.loc[sequence_values.isin(sequences)] + logger.info(f" -> retained {len(index)} samples.") + + pick_frames_criteria = [] + if self.remove_empty_masks: + logger.info("Culling samples with empty masks.") + + if len(index) > self.remove_empty_masks_poll_whole_table_threshold: + # APPROACH 1: find empty masks and drop indices. + # dev load: 17s / 15 s (3.1M / 500K) + stmt = sa.select( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + ).where(self.frame_annotations_type._mask_mass == 0) + with Session(self._sql_engine) as session: + to_remove = session.execute(stmt).all() + + # Pandas uses np.int64 for integer types, so we have to case + # we might want to read it to pandas DataFrame directly to avoid the loop + to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove] + index.drop(to_remove, errors="ignore", inplace=True) + else: + # APPROACH 3: load index into a temp table and join with annotations + # dev load: 94 s / 23 s (3.1M / 500K) + pick_frames_criteria.append( + sa.or_( + self.frame_annotations_type._mask_mass.is_(None), + self.frame_annotations_type._mask_mass != 0, + ) + ) + + if self.pick_frames_sql_clause: + logger.info("Applying the custom SQL clause.") + pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause)) + + if pick_frames_criteria: + index = self._pick_frames_by_criteria(index, pick_frames_criteria) + + logger.info(f" -> retained {len(index)} samples.") + + if self.limit_to > 0: + logger.info(f"Limiting dataset to first {self.limit_to} frames") + index = index.sort_index().iloc[: self.limit_to] + + return index.reset_index() + + def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame: + IndexTable = self._get_temp_index_table_instance() + with self._sql_engine.connect() as connection: + IndexTable.create(connection) + # we don’t let pandas’s `to_sql` create the table automatically as + # the table would be permanent, so we create it and append with pandas + n_rows = index.to_sql(IndexTable.name, connection, if_exists="append") + assert n_rows == len(index) + sa_type = self.frame_annotations_type + stmt = ( + sa.select(IndexTable) + .select_from( + IndexTable.join( + self.frame_annotations_type, + sa.and_( + sa_type.sequence_name == IndexTable.c.sequence_name, + sa_type.frame_number == IndexTable.c.frame_number, + ), + ) + ) + .where(*criteria) + ) + return pd.read_sql_query(stmt, connection).set_index( + ["sequence_name", "frame_number"] + ) + + def _build_index_from_db(self, sequences: Optional[pd.Series]): + logger.info("Loading sequcence-frame index from the database") + stmt = sa.select( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + self.frame_annotations_type._image_path, + sa.null().label("subset"), + ) + where_conditions = [] + if sequences is not None: + logger.info(" applying filtered sequences") + where_conditions.append( + self.frame_annotations_type.sequence_name.in_(sequences.tolist()) + ) + + if self.remove_empty_masks: + logger.info(" excluding samples with empty masks") + where_conditions.append( + sa.or_( + self.frame_annotations_type._mask_mass.is_(None), + self.frame_annotations_type._mask_mass != 0, + ) + ) + + if self.pick_frames_sql_clause: + logger.info(" applying custom SQL clause") + where_conditions.append(sa.text(self.pick_frames_sql_clause)) + + if where_conditions: + stmt = stmt.where(*where_conditions) + + if self.limit_to > 0: + logger.info(f"Limiting dataset to first {self.limit_to} frames") + stmt = stmt.order_by( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + ).limit(self.limit_to) + + with self._sql_engine.connect() as connection: + index = pd.read_sql_query(stmt, connection) + + logger.info(f" -> loaded {len(index)} samples.") + return index + + def _sort_index_(self, index): + logger.info("Sorting the index by sequence and frame number.") + index.sort_values(["sequence_name", "frame_number"], inplace=True) + logger.info(" -> Done.") + + def _load_filter_eval_batches(self): + assert self.eval_batches_file + logger.info(f"Loading eval batches from {self.eval_batches_file}") + + if not os.path.isfile(self.eval_batches_file): + # The batch indices file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError( + f"Looking for dataset json file in {self.eval_batches_file}. " + + "Please specify a correct dataset_root folder." + ) + + with open(self.eval_batches_file, "r") as f: + eval_batches = json.load(f) + + # limit the dataset to sequences to allow multiple evaluations in one file + pick_sequences = set(self.pick_sequences) + if self.pick_categories: + cat_to_seq = self.category_to_sequence_names() + pick_sequences.update( + seq for cat in self.pick_categories for seq in cat_to_seq[cat] + ) + + if pick_sequences: + old_len = len(eval_batches) + eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences] + logger.warn( + f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}" + ) + + if self.exclude_sequences: + old_len = len(eval_batches) + exclude_sequences = set(self.exclude_sequences) + eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences] + logger.warn( + f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}" + ) + + return eval_batches + + def _stratified_sample_index(self, index): + # NOTE this stratified sampling can be done more efficiently in + # the no-subset case above if it is added to the SQL query. + # We keep this generic implementation since no-subset case is uncommon + index = index.groupby("sequence_name", group_keys=False).apply( + lambda seq_frames: seq_frames.sample( + min(len(seq_frames), self.n_frames_per_sequence), + random_state=( + _seq_name_to_seed(seq_frames.iloc[0]["sequence_name"]) + self.seed + ), + ) + ) + logger.info(f" -> retained {len(index)} samples aster stratified sampling.") + return index + + def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]: + return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"] + + def _get_frame_no_coalesced_ts_by_row_indices( + self, + idxs: Sequence[int], + seq_name: Optional[str] = None, + subset_filter: Union[Sequence[str], str, None] = None, + ) -> Tuple[pd.DataFrame, Sequence[int]]: + """ + Loads timestamps for given index rows belonging to the same sequence. + If seq_name is known, it speeds up the computation. + Raises ValueError if `idxs` do not all belong to a single sequences . + """ + index_slice = self._index.iloc[idxs] + if subset_filter is not None: + if isinstance(subset_filter, str): + subset_filter = [subset_filter] + indicator = index_slice["subset"].isin(subset_filter) + index_slice = index_slice.loc[indicator] + idxs = [i for i, isin in zip(idxs, indicator) if isin] + + frames = index_slice.index.get_level_values("frame_number").tolist() + if seq_name is None: + seq_name_list = index_slice.index.get_level_values("sequence_name").tolist() + seq_name_set = set(seq_name_list) + if len(seq_name_set) > 1: + raise ValueError("Given indices belong to more than one sequence.") + elif len(seq_name_set) == 1: + seq_name = seq_name_list[0] + + coalesced_ts = sa.sql.functions.coalesce( + self.frame_annotations_type.frame_timestamp, 0 + ) + stmt = sa.select( + coalesced_ts.label("frame_timestamp"), + self.frame_annotations_type.frame_number, + ).where( + self.frame_annotations_type.sequence_name == seq_name, + self.frame_annotations_type.frame_number.in_(frames), + ) + + with self._sql_engine.connect() as connection: + frame_no_ts = pd.read_sql_query(stmt, connection) + + if len(frame_no_ts) != len(index_slice): + raise ValueError( + "Not all indices are found in the database; " + "do they belong to more than one sequence?" + ) + + return frame_no_ts, idxs + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + def _get_temp_index_table_instance(self, table_name: str = "__index"): + CachedTable = self.frame_annotations_type.metadata.tables.get(table_name) + if CachedTable is not None: # table definition is not idempotent + return CachedTable + + return sa.Table( + table_name, + self.frame_annotations_type.metadata, + sa.Column("sequence_name", sa.String, primary_key=True), + sa.Column("frame_number", sa.Integer, primary_key=True), + sa.Column("_image_path", sa.String), + sa.Column("subset", sa.String), + prefixes=["TEMP"], # NOTE SQLite specific! + ) + + +def _seq_name_to_seed(seq_name) -> int: + """Generates numbers in [0, 2 ** 28)""" + return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16) + + +def _safe_as_tensor(data, dtype): + return torch.tensor(data, dtype=dtype) if data is not None else None diff --git a/pytorch3d/implicitron/dataset/sql_dataset_provider.py b/pytorch3d/implicitron/dataset/sql_dataset_provider.py new file mode 100644 index 00000000..ab161e8d --- /dev/null +++ b/pytorch3d/implicitron/dataset/sql_dataset_provider.py @@ -0,0 +1,424 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +from typing import List, Optional, Tuple, Type + +import numpy as np + +from omegaconf import DictConfig, OmegaConf + +from pytorch3d.implicitron.dataset.dataset_map_provider import ( + DatasetMap, + DatasetMapProviderBase, + PathManagerFactory, +) +from pytorch3d.implicitron.tools.config import ( + expand_args_fields, + registry, + run_auto_creation, +) + +from .sql_dataset import SqlIndexDataset + + +_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "") + +# _NEED_CONTROL is a list of those elements of SqlIndexDataset which +# are not directly specified for it in the config but come from the +# DatasetMapProvider. +_NEED_CONTROL: Tuple[str, ...] = ( + "path_manager", + "subsets", + "sqlite_metadata_file", + "subset_lists_file", +) + +logger = logging.getLogger(__name__) + + +@registry.register +class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] + """ + Generates the training, validation, and testing dataset objects for + a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base. + + The dataset is organized in the filesystem as follows:: + + self.dataset_root + ├── + │ ├── + │ │ ├── depth_masks + │ │ ├── depths + │ │ ├── images + │ │ ├── masks + │ │ └── pointcloud.ply + │ ├── + │ │ ├── depth_masks + │ │ ├── depths + │ │ ├── images + │ │ ├── masks + │ │ └── pointcloud.ply + │ ├── ... + │ ├── + │ ├── set_lists + │ ├── .json + │ ├── .json + │ ├── ... + │ ├── .json + │ ├── eval_batches + │ │ ├── .json + │ │ ├── .json + │ │ ├── ... + │ │ ├── .json + │ ├── frame_annotations.jgz + │ ├── sequence_annotations.jgz + ├── + ├── ... + ├── + ├── set_lists + ├── .sqlite + ├── .sqlite + ├── ... + ├── .sqlite + ├── eval_batches + │ ├── .json + │ ├── .json + │ ├── ... + │ ├── .json + + The dataset contains sequences named `` that may be partitioned by + directories such as `` e.g. representing categories but they + can also be stored in a flat structure. Each sequence folder contains the list of + sequence images, depth maps, foreground masks, and valid-depth masks + `images`, `depths`, `masks`, and `depth_masks` respectively. Furthermore, + `set_lists/` dirtectories (with partitions or global) store json or sqlite files + `.`, each describing a certain sequence subset. + These subset path conventions are not hard-coded and arbitrary relative path can be + specified by setting `self.subset_lists_path` to the relative path w.r.t. + dataset root. + + Each `.json` file contains the following dictionary:: + + { + "train": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + "val": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + "test": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + ] + + defining the list of frames (identified with their `sequence_name` and + `frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of + SQLite format, `.sqlite` contains a table with the header:: + + | sequence_name | frame_number | image_path | subset | + + Note that `frame_number` can be obtained only from the metadata and + does not necesarrily correspond to the numeric suffix of the corresponding image + file name (e.g. a file `//images/frame00005.jpg` can + have its frame number set to `20`, not 5). + + Each `.json` file contains a list of evaluation examples + in the following form:: + + [ + [ # batch 1 + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + [ # batch 2 + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + ] + + Note that the evaluation examples always come from the `"test"` subset of the dataset. + (test frames can repeat across batches). The batches can contain single element, + which is typical in case of regular radiance field fitting. + + Args: + subset_lists_path: The relative path to the dataset subset definition. + For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json". + By default (None), dataset is not partitioned to subsets (in that case, setting + `ignore_subsets` will speed up construction) + dataset_root: The root folder of the dataset. + metadata_basename: name of the SQL metadata file in dataset_root; + not expected to be changed by users + test_on_train: Construct validation and test datasets from + the training subset; note that in practice, in this + case all subset dataset objects will be same + only_test_set: Load only the test set. Incompatible with `test_on_train`. + ignore_subsets: Don’t filter by subsets in the dataset; note that in this + case all subset datasets will be same + eval_batch_num_training_frames: Add a certain number of training frames to each + eval batch. Useful for evaluating models that require + source views as input (e.g. NeRF-WCE / PixelNeRF). + dataset_args: Specifies additional arguments to the + JsonIndexDataset constructor call. + path_manager_factory: (Optional) An object that generates an instance of + PathManager that can translate provided file paths. + path_manager_factory_class_type: The class type of `path_manager_factory`. + """ + + category: Optional[str] = None + subset_list_name: Optional[str] = None # TODO: docs + # OR + subset_lists_path: Optional[str] = None + eval_batches_path: Optional[str] = None + + dataset_root: str = _CO3D_SQL_DATASET_ROOT + metadata_basename: str = "metadata.sqlite" + + test_on_train: bool = False + only_test_set: bool = False + ignore_subsets: bool = False + train_subsets: Tuple[str, ...] = ("train",) + val_subsets: Tuple[str, ...] = ("val",) + test_subsets: Tuple[str, ...] = ("test",) + + eval_batch_num_training_frames: int = 0 + + # this is a mould that is never constructed, used to build self._dataset_map values + dataset_class_type: str = "SqlIndexDataset" + dataset: SqlIndexDataset + + path_manager_factory: PathManagerFactory + path_manager_factory_class_type: str = "PathManagerFactory" + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + if self.only_test_set and self.test_on_train: + raise ValueError("Cannot have only_test_set and test_on_train") + + if self.ignore_subsets and not self.only_test_set: + self.test_on_train = True # no point in loading same data 3 times + + path_manager = self.path_manager_factory.get() + + sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename) + sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file) + + if not os.path.isfile(sqlite_metadata_file): + # The sqlite_metadata_file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError( + f"Looking for frame annotations in {sqlite_metadata_file}." + + " Please specify a correct dataset_root folder." + + " Note: By default the root folder is taken from the" + + " CO3D_SQL_DATASET_ROOT environment variable." + ) + + if self.subset_lists_path and self.subset_list_name: + raise ValueError( + "subset_lists_path and subset_list_name cannot be both set" + ) + + subset_lists_file = self._get_lists_file("set_lists") + + # setup the common dataset arguments + common_dataset_kwargs = { + **getattr(self, f"dataset_{self.dataset_class_type}_args"), + "sqlite_metadata_file": sqlite_metadata_file, + "dataset_root": self.dataset_root, + "subset_lists_file": subset_lists_file, + "path_manager": path_manager, + } + + if self.category: + logger.info(f"Forcing category filter in the datasets to {self.category}") + common_dataset_kwargs["pick_categories"] = self.category.split(",") + + # get the used dataset type + dataset_type: Type[SqlIndexDataset] = registry.get( + SqlIndexDataset, self.dataset_class_type + ) + expand_args_fields(dataset_type) + + if subset_lists_file is not None and not os.path.isfile(subset_lists_file): + available_subsets = self._get_available_subsets( + OmegaConf.to_object(common_dataset_kwargs["pick_categories"]) + ) + msg = f"Cannot find subset list file {self.subset_lists_path}." + if available_subsets: + msg += f" Some of the available subsets: {str(available_subsets)}." + raise ValueError(msg) + + train_dataset = None + val_dataset = None + if not self.only_test_set: + # load the training set + logger.debug("Constructing train dataset.") + train_dataset = dataset_type( + **common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets) + ) + logger.info(f"Train dataset: {str(train_dataset)}") + + if self.test_on_train: + assert train_dataset is not None + val_dataset = test_dataset = train_dataset + else: + # load the val and test sets + if not self.only_test_set: + # NOTE: this is always loaded in JsonProviderV2 + logger.debug("Extracting val dataset.") + val_dataset = dataset_type( + **common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets) + ) + logger.info(f"Val dataset: {str(val_dataset)}") + + logger.debug("Extracting test dataset.") + eval_batches_file = self._get_lists_file("eval_batches") + del common_dataset_kwargs["eval_batches_file"] + test_dataset = dataset_type( + **common_dataset_kwargs, + subsets=self._get_subsets(self.test_subsets, True), + eval_batches_file=eval_batches_file, + ) + logger.info(f"Test dataset: {str(test_dataset)}") + + if ( + eval_batches_file is not None + and self.eval_batch_num_training_frames > 0 + ): + self._extend_eval_batches(test_dataset) + + self._dataset_map = DatasetMap( + train=train_dataset, val=val_dataset, test=test_dataset + ) + + def _get_subsets(self, subsets, is_eval: bool = False): + if self.ignore_subsets: + return None + + if is_eval and self.eval_batch_num_training_frames > 0: + # we will need to have training frames for extended batches + return list(subsets) + list(self.train_subsets) + + return subsets + + def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None: + rng = np.random.default_rng(seed=0) + eval_batches = test_dataset.get_eval_batches() + if eval_batches is None: + raise ValueError("Eval batches were not loaded!") + + for batch in eval_batches: + sequence = batch[0][0] + seq_frames = list( + test_dataset.sequence_frames_in_order(sequence, self.train_subsets) + ) + idx_to_add = rng.permutation(len(seq_frames))[ + : self.eval_batch_num_training_frames + ] + batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add) + + @classmethod + def dataset_tweak_args(cls, type, args: DictConfig) -> None: + """ + Called by get_default_args. + Certain fields are not exposed on each dataset class + but rather are controlled by this provider class. + """ + for key in _NEED_CONTROL: + del args[key] + + def create_dataset(self): + # No `dataset` member of this class is created. + # The dataset(s) live in `self.get_dataset_map`. + pass + + def get_dataset_map(self) -> DatasetMap: + return self._dataset_map # pyre-ignore [16] + + def _get_available_subsets(self, categories: List[str]): + """ + Get the available subset names for a given category folder (if given) inside + a root dataset folder `dataset_root`. + """ + path_manager = self.path_manager_factory.get() + + subsets: List[str] = [] + for prefix in [""] + categories: + set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists") + if not ( + (path_manager is not None) and path_manager.isdir(set_list_dir) + ) and not os.path.isdir(set_list_dir): + continue + + set_list_files = (os.listdir if path_manager is None else path_manager.ls)( + set_list_dir + ) + subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files) + + return subsets + + def _get_lists_file(self, flavor: str) -> Optional[str]: + if flavor == "eval_batches": + subset_lists_path = self.eval_batches_path + else: + subset_lists_path = self.subset_lists_path + + if not subset_lists_path and not self.subset_list_name: + return None + + category_elem = "" + if self.category and "," not in self.category: + # if multiple categories are given, looking for global set lists + category_elem = self.category + + subset_lists_path = subset_lists_path or ( + os.path.join( + category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}" + ) + ) + + assert subset_lists_path + path_manager = self.path_manager_factory.get() + # try absolute path first + subset_lists_file = _get_local_path_check_extensions( + subset_lists_path, path_manager + ) + if subset_lists_file: + return subset_lists_file + + full_path = os.path.join(self.dataset_root, subset_lists_path) + subset_lists_file = _get_local_path_check_extensions(full_path, path_manager) + + if not subset_lists_file: + raise FileNotFoundError( + f"Subset lists path given but not found: {full_path}" + ) + + return subset_lists_file + + +def _get_local_path_check_extensions( + path, path_manager, extensions=("", ".sqlite", ".json") +) -> Optional[str]: + for ext in extensions: + local = _local_path(path_manager, path + ext) + if os.path.isfile(local): + return local + + return None + + +def _local_path(path_manager, path: str) -> str: + if path_manager is None: + return path + return path_manager.get_local_path(path) diff --git a/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py b/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py new file mode 100644 index 00000000..4640feb2 --- /dev/null +++ b/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any, Dict, Optional, Tuple + +from pytorch3d.implicitron.dataset.data_loader_map_provider import ( + DataLoaderMap, + SceneBatchSampler, + SequenceDataLoaderMapProvider, +) +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap +from pytorch3d.implicitron.dataset.frame_data import FrameData +from pytorch3d.implicitron.tools.config import registry, run_auto_creation + +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D +# and support both eval_batches protocols +@registry.register +class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider): + """ + Implementation of DataLoaderMapProviderBase that may use internal eval batches for + the test dataset. In particular, if `eval_batches_relpath` is set, it loads + eval batches from that json file, otherwise test set is treated in the same way as + train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type` + are respected. + + If conditioning is not required, then the batch size should + be set as 1, and most of the fields do not matter. + + If conditioning is required, each batch will contain one main + frame first to predict and the, rest of the elements are for + conditioning. + + If images_per_seq_options is left empty, the conditioning + frames are picked according to the conditioning type given. + This does not have regard to the order of frames in a + scene, or which frames belong to what scene. + + If images_per_seq_options is given, then the conditioning types + must be SAME and the remaining fields are used. + + Members: + batch_size: The size of the batch of the data loader. + num_workers: Number of data-loading threads in each data loader. + dataset_length_train: The number of batches in a training epoch. Or 0 to mean + an epoch is the length of the training set. + dataset_length_val: The number of batches in a validation epoch. Or 0 to mean + an epoch is the length of the validation set. + dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of + batches in a testing epoch. Or 0 to mean an epoch is the length of the test + set. + images_per_seq_options: Possible numbers of frames sampled per sequence in a batch. + If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial + value. Empty (the default) means that we are not careful about which frames + come from which scene. + sample_consecutive_frames: if True, will sample a contiguous interval of frames + in the sequence. It first sorts the frames by timestimps when available, + otherwise by frame numbers, finds the connected segments within the sequence + of sufficient length, then samples a random pivot element among them and + ideally uses it as a middle of the temporal window, shifting the borders + where necessary. This strategy mitigates the bias against shorter segments + and their boundaries. + consecutive_frames_max_gap: if a number > 0, then used to define the maximum + difference in frame_number of neighbouring frames when forming connected + segments; if both this and consecutive_frames_max_gap_seconds are 0s, + the whole sequence is considered a segment regardless of frame numbers. + consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the + maximum difference in frame_timestamp of neighbouring frames when forming + connected segments; if both this and consecutive_frames_max_gap are 0s, + the whole sequence is considered a segment regardless of frame timestamps. + """ + + batch_size: int = 1 + num_workers: int = 0 + + dataset_length_train: int = 0 + dataset_length_val: int = 0 + dataset_length_test: int = 0 + + images_per_seq_options: Tuple[int, ...] = () + sample_consecutive_frames: bool = False + consecutive_frames_max_gap: int = 0 + consecutive_frames_max_gap_seconds: float = 0.1 + + def __post_init__(self): + run_auto_creation(self) + + def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap: + """ + Returns a collection of data loaders for a given collection of datasets. + """ + train = self._make_generic_data_loader( + datasets.train, + self.dataset_length_train, + datasets.train, + ) + + val = self._make_generic_data_loader( + datasets.val, + self.dataset_length_val, + datasets.train, + ) + + if datasets.test is not None and datasets.test.get_eval_batches() is not None: + test = self._make_eval_data_loader(datasets.test) + else: + test = self._make_generic_data_loader( + datasets.test, + self.dataset_length_test, + datasets.train, + ) + + return DataLoaderMap(train=train, val=val, test=test) + + def _make_eval_data_loader( + self, + dataset: Optional[DatasetBase], + ) -> Optional[DataLoader[FrameData]]: + if dataset is None: + return None + + return DataLoader( + dataset, + batch_sampler=dataset.get_eval_batches(), + **self._get_data_loader_common_kwargs(dataset), + ) + + def _make_generic_data_loader( + self, + dataset: Optional[DatasetBase], + num_batches: int, + train_dataset: Optional[DatasetBase], + ) -> Optional[DataLoader[FrameData]]: + """ + Returns the dataloader for a dataset. + + Args: + dataset: the dataset + num_batches: possible ceiling on number of batches per epoch + train_dataset: the training dataset, used if conditioning_type==TRAIN + conditioning_type: source for padding of batches + """ + if dataset is None: + return None + + data_loader_kwargs = self._get_data_loader_common_kwargs(dataset) + + if len(self.images_per_seq_options) > 0: + # this is a typical few-view setup + # conditioning comes from the same subset since subsets are split by seqs + batch_sampler = SceneBatchSampler( + dataset, + self.batch_size, + num_batches=len(dataset) if num_batches <= 0 else num_batches, + images_per_seq_options=self.images_per_seq_options, + sample_consecutive_frames=self.sample_consecutive_frames, + consecutive_frames_max_gap=self.consecutive_frames_max_gap, + consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds, + ) + return DataLoader( + dataset, + batch_sampler=batch_sampler, + **data_loader_kwargs, + ) + + if self.batch_size == 1: + # this is a typical many-view setup (without conditioning) + return self._simple_loader(dataset, num_batches, data_loader_kwargs) + + # edge case: conditioning on train subset, typical for Nerformer-like many-view + # there is only one sequence in all datasets, so we condition on another subset + return self._train_loader( + dataset, train_dataset, num_batches, data_loader_kwargs + ) + + def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]: + return { + "num_workers": self.num_workers, + "collate_fn": dataset.frame_data_type.collate, + } diff --git a/setup.py b/setup.py index 357073ab..54e6283f 100755 --- a/setup.py +++ b/setup.py @@ -164,6 +164,7 @@ setup( "tqdm>4.29.0", "matplotlib", "accelerate", + "sqlalchemy>=2.0", ], }, entry_points={ diff --git a/tests/implicitron/data/sql_dataset/set_lists_100.json b/tests/implicitron/data/sql_dataset/set_lists_100.json new file mode 100644 index 00000000..96dbe2b4 --- /dev/null +++ b/tests/implicitron/data/sql_dataset/set_lists_100.json @@ -0,0 +1 @@ +{"train": [["cat0_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000000.jpg"], ["cat0_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000002.jpg"], ["cat0_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000004.jpg"], ["cat0_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000006.jpg"], ["cat0_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000008.jpg"], ["cat0_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000000.jpg"], ["cat0_seq1", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000002.jpg"], ["cat0_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000004.jpg"], ["cat0_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000006.jpg"], ["cat0_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000008.jpg"], ["cat0_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000000.jpg"], ["cat0_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000002.jpg"], ["cat0_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000004.jpg"], ["cat0_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000006.jpg"], ["cat0_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000008.jpg"], ["cat0_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000000.jpg"], ["cat0_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000002.jpg"], ["cat0_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000004.jpg"], ["cat0_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000006.jpg"], ["cat0_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000008.jpg"], ["cat0_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000000.jpg"], ["cat0_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000002.jpg"], ["cat0_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000004.jpg"], ["cat0_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000006.jpg"], ["cat0_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000008.jpg"], ["cat1_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000000.jpg"], ["cat1_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000002.jpg"], ["cat1_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000004.jpg"], ["cat1_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000006.jpg"], ["cat1_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000008.jpg"], ["cat1_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000000.jpg"], ["cat1_seq1", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000002.jpg"], ["cat1_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000004.jpg"], ["cat1_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000006.jpg"], ["cat1_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000008.jpg"], ["cat1_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000000.jpg"], ["cat1_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000002.jpg"], ["cat1_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000004.jpg"], ["cat1_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000006.jpg"], ["cat1_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000008.jpg"], ["cat1_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000000.jpg"], ["cat1_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000002.jpg"], ["cat1_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000004.jpg"], ["cat1_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000006.jpg"], ["cat1_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000008.jpg"], ["cat1_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000000.jpg"], ["cat1_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000002.jpg"], ["cat1_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000004.jpg"], ["cat1_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000006.jpg"], ["cat1_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000008.jpg"]], "test": [["cat0_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000001.jpg"], ["cat0_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000003.jpg"], ["cat0_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000005.jpg"], ["cat0_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000007.jpg"], ["cat0_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000009.jpg"], ["cat0_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000001.jpg"], ["cat0_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000003.jpg"], ["cat0_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000005.jpg"], ["cat0_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000007.jpg"], ["cat0_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000009.jpg"], ["cat0_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000001.jpg"], ["cat0_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000003.jpg"], ["cat0_seq2", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000005.jpg"], ["cat0_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000007.jpg"], ["cat0_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000009.jpg"], ["cat0_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000001.jpg"], ["cat0_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000003.jpg"], ["cat0_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000005.jpg"], ["cat0_seq3", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000007.jpg"], ["cat0_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000009.jpg"], ["cat0_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000001.jpg"], ["cat0_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000003.jpg"], ["cat0_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000005.jpg"], ["cat0_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000007.jpg"], ["cat0_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000009.jpg"], ["cat1_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000001.jpg"], ["cat1_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000003.jpg"], ["cat1_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000005.jpg"], ["cat1_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000007.jpg"], ["cat1_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000009.jpg"], ["cat1_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000001.jpg"], ["cat1_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000003.jpg"], ["cat1_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000005.jpg"], ["cat1_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000007.jpg"], ["cat1_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000009.jpg"], ["cat1_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000001.jpg"], ["cat1_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000003.jpg"], ["cat1_seq2", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000005.jpg"], ["cat1_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000007.jpg"], ["cat1_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000009.jpg"], ["cat1_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000001.jpg"], ["cat1_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000003.jpg"], ["cat1_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000005.jpg"], ["cat1_seq3", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000007.jpg"], ["cat1_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000009.jpg"], ["cat1_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000001.jpg"], ["cat1_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000003.jpg"], ["cat1_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000005.jpg"], ["cat1_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000007.jpg"], ["cat1_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000009.jpg"]]} \ No newline at end of file diff --git a/tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite b/tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..2d8ea3b2bb10d14175b1f8d87a4cebffa0210410 GIT binary patch literal 81920 zcmeI5cT^PD`~R0+%CZ%Wh}c*Qih@Y9hP_unPys>30-_?)71*WOVoU_1Mq@Xz#1=JC zW8%ja``)X@8cVR3*t^CUOJXz{{VsRMVTLiB`JHp;^ZVnodyeZqyUY8TxvzU?pZDyY zy9qsGbBZ$znWn$Ye4X>%Usoe{XM2)(4Jn1MB||oX7w6&_(9o zFIuaOlR3LQkhwh2W6eIbi)Mmq(!c$I{ds!}*jvEf0`?ZLw}8C`=(fNoo(?V^-rf$M zmK3KB$<4?tOwG$kNzKnU6&GZ$n(4*jPV<%GLH4`6YQnG71gd;(AARiG0@^igWTZii%V7%!a7g z_%Q3Ab0ku7@=~)hQp~Bv+3!rjyk3+uI>Qhi8z0X9zqH&G+DPkB{w5K%i2Q17O2c)F3ibK%Q0KuO!8LyUh8&Q-(8fGR+PfjgF?Kv z-5z2p$xqM8&q^6$8ksV3)Cgu9I!8u@^@*+Q2|aps>mJrC$q*BnWbom296zRC*?LI} zRf`XuT|6Ru9maX)Z$(>_j&XGH@bz{0q%|2DRo+-7bm9$=Dldu~BX1MGyLS)INzX9xdK|xR*#DSsU*D`74EwVN>X!kibtgsrI`vd-Zc`qZ_c;%rJ}ckyP|Jb_PBRQRlMQf z`+}9e(aG7Pr>}#Iz%MGWZi;nemK0k@inTKcx#lu?qm-=V=$NC5QJNoPTOZs#Allo)&L$((y-eAAkTfp7|_76@%pRF&`7wOaVz4ej$ zHhOQpyWUCnr|y>Sv~Hhnvu>qsu5O}kgf2_hPuE4)PS;#lN9U@2p}nU)uRWyQrd_R_ zul-azT02aetc}sWul3c|)2cNuHNR>uYL03wn(sA>G*dMnYVtG#HQhBp(^})D(W~F6 ze^*~o|E$L9|EQO$XR61k3)HFVp6W1lpt_N|hFY$As=A>%srpg1QMFt(TUDkiQl+c< zs3KKuRNg9gm6PkAuD7f^V!zp2z}^D(7O=N~y#?$oU~hr{#TL*zdO0{)zhVxo`06yo zHJu+T)A+H=5Ps~O%8#7}^JB*peykY8kL3e(j$W>n`3G1-Wo4Mmdv~yFe}1e?;>RwD z{Mfl4KX&TNj~)B)V?}R%EKlHVU2ren)&=+Ey*tRY2R~NE^JAAde(cXHB ze(Vy!kDXidV<&%p?AVGQEByGe-1qI)+48r1yLGm(1@GO#u0H%&*_t4-u&39 zDL-~>!jBb=`LVnaZ|i~$ysZoN;=Mb_wIM%NHsHrD_4%=LJ$~#|mmfQN@?%9Eek}Lk zZCy}p-qr=V^WGikT8kem-{Z$FHTkh~4Sww8#*ZEK{8*vm$8s%i>jE{rtqWB1-W}ko z;>Su?e(a*;$IdSN*vXk6J38@Wg(E+fD|lNMAlF)t#;(>62h#5Xi2qj`W%^h8QoT|) zO;=01Kx@!^qw!O3(EqM4(Y@4-)46D;YTwh$*LbPFw$2R5);-h}YyZ}Mtaa9Wrm3l( zr$4Pv)?L$OX@AoeXg)D{SJMY?trel_M|pR zb48P({#9L|U#kz%ZP#_u?$^d@PG}O{=Mtc1pC(3qT-{GUM_*UB zRM$-Voi<3bRRh#})zSKidaZ7@&QrTY>#bR%2~=;f?uh+nZvlG?*jvEf0`?ZLw}8C` z>@8q#0jXL*qi}J64~|saj;52~-u#N&f;1A`&Lp=XB)AbwZmA@=YTKg0Bsg#Tn-mh9 zw7qQ*2@cs_8AyU{wqSsxLg|o9o?$ZgCjpZ=i3CjML=rHW`;maj+?NFBZIXRRz+~=C z0w!|;37E{iDl+#Z&oG&Lkbub?PXZ=$90{1r-ATY?jwJz;IfevG=4cWynY)pI$=tOf za~JXqlR1h6Oy)=uFqu1(fXN&|0w!}f37E`bBw#XkA_0>bl7Pt!Dl!|%Gfd`C5-^$H zCjpbWBMF$y9Z0}rZchRxb2}0+ncI?p$=rqnOy-b^%)#UtCUX!8n9PABU@`}gfXUpN z1Waas5-^!tk$}nUM*=3ZFA12;Eh{p&AkQ$FeMrD$ZcYLwb2Ab!nY~HCWNu0VCUX-K zFqs>ZfXUp51WaZ_MP@JZ43oJb37E_cNWf&SPXZ=$JrXdP>ym)U>`4M9a~%>enLS9r zWUgJ2*_}MYWUfU5Ci8nFU^3Sv0h74~37E`oBw#Y@Nx)>*k$}mpB>|IJQ;}Iso?$Yp zNWf%vB>|IJNdhLb3kjIa&Lm(mJCT6N>_`G8vw{RnX1PM4bk)kN^8u7KW%@R{q1q*y zQyNz_P!+kZQeLt?X}{T9z}^D(7O=N~y#?$o@ITT54iDN%tLeYGqO^%Mvu%M zF|shLBqQ6LT9}tpSUjvXtJa!a!t9@C4yi>^lw71tE_?Pj=Z;d`9oN`KM3r?)G#hdtSq zcMW^_J9edF*wfn;+QXjg%DaZW!#j4RV%XE$723m|?22&MSA74k(wE5ePxZI-XRTNJ z@6dm*U!tFBy#k=*f282C@20&4>@8q#0ecJBTfp7|_7Ensf}dkffGz}^D(7O=N~y#?$oU~d6?3)ox0-U9#4 z7EmhO$r%IH+Xt?1A1L2GaC!T{nRyW4^!9<{+Xsrb59BTicXAS-V*bBUHbSQJ(0HqY zl%1TrJ0{C>tWQXLD>Er;c{#X@3#&ga%ynGMIH>I1w0`||6*DL*R_6$Cfy^Zj)aD)q zv?jlQ5KwVz=Ree(m2Ta`ihmR{ke>@;GMMwTXe*Q7-jPJU;>U_Ah0nnhH2+XrG;e7M z#y|@OY#sq02sC%ctkTu3hssT_Y`>ooc4X({>EACzQ% zrHTAlP*Iu5rCQ?7kh!jfVyjgWZ!0PS9NHPO*ufZk_#H+5Sd&UPnJTq*X4iHF7vunXv$^ z`|N(`_IaUr;U9kZe6I*xNA`~;F2sbzAvZblk-PGdDzkkC}*2f10n zCthWc%OP{IsU*aMBqtsuEpgiNEO8>N7`ARS+?n@2c%|K7T%#;D&YSLwpQRh&CVO0{C}L|0Qpeq{lD3*x(|^+`raUr-E|7Y|CNRxHmZSZx^;tL1!F;t(cR(rTlaxa^B=*4 z)-mv}phVC%`J%`}ib>B$&Of9C(l-ZOf4FOe4RZ1CUu=YD=L2!amNJvBTH^G*0hhQm-hfM- zZt*T4PFtQOPJ|tgb$~a%d<-r&mcfqyIOBlS|5#pSC&HZp*MZ~m!C2-4!435y{N2Z& z!RX)Si7gbheB%7mx`Fh)0hhQm-hfM-Zt*T4PFtQOPK5n$#^5Wh7aJo!^@Q!uyxdvq z_E_wp*l1k7pf3D-X-}hZNepafjDWJDEI54IKCwr2O?B2Fkh(SC0++@baDmgU-37pD z%d^03)?m;9AGo^n2Jq+nmiUY;2<}DKN7(zNAS(nMG>dR$%XdN~^2D&Az7clZoS z{NXq{l^TyD=Cs1`nRk#Ty%rqzcm{lL>PSuBz^69^m?R?L$f2UXTIT6k})(H+xX$BYEF1KvGu^M$xlEddQG4R098gNgc8=f)P6Slq{2&U!M!M`@# zDROL^T&pFngv_#Zk`Qk%Iq~+=5~nTC5+}k}H|xRBj;V0q#JfoOQ$t*+y8?E-Yz7T8 zB4NPRaB#kVJy^T*bo^y!EpnUmn>gY&|6d~0X*G3KKFT)Ev5x)a*)nOb{Vy}u;de9s zM@SHTV-Un~|I!$PAddT|Te&mJmF}OmJloBPuwwAw#YW$#CU9h{&Txh~0hVVyMz@TM z!OTB(aNxaqFd)DU?>O8K^(udbE>27k`&r__N6vQ!LG+yg7r8XffQy`N=`KJ{Tb@Nu zgbgD;F#g&w4qkKcz_Y!(!{_KAN{YXOo+i}6zaPJWCiV5jA;&$S_Jf1?cF0k2N0u^E z?dqIC5Ort31ul&<-~y*xx(k5QmS=$zVfo%eZurx#&EPe6IrI!|0WWOn0QF5eqs0ZU z&{Ay<*z~L~UN`F*`mJ#~`0T)Jkz+j5o$m~)jB#p31DCWk&VWmrZs{%{ODs`{@#)xG(+9wfF~f z{^xffePh6N{n8i%t{0?RxeL4?ZF!bB5mv0)90FH(?*V(8dE$u9a=74LG7fE=Vf-NC z6?$E39a^v~7S~Etz!iqMunWi$iQM!aA31+)2hw*2T;$R?11@s9rMmz*ZFv?s5jNNO z$+GT~2|%^E1a6*~4mNh=%px`N(K|AaAoVZLhEKl;?ZEso{q(DKZ8> zo?HiCjeKS_A5!2uPm=K2gN;R&Y~~u(BBwt`my2AQIl5frbSKdYkkgiDkrQG0s=eQV zZwg#s?Ot7A<4+nuU8xN2nOhHM4{nQ*cWan<*9*OQGzl&p{T+Ou+b3?6+x&mA&Hp#I z`u}jpczLF**fuj)n)@=7p*m*}Ox+m-bKJi)&LEiM{^^$PjB=&>r!CL+GMh7aknIOO zI=+t+JTF>iKJ&(YA2i4Ha}R)%=YBw6bdE%Renat`PtL>9YhHr2sy0Ae!Ic+3ojBZ;y{}lR`&@KMBf>3y`VJCfa?Y6mhJ*CNL!vIZgU13_KgQy zf&iMjv=Fu^yo`n}=?`B|odxQZ55S8~Jx8mWxPflhy5bAovrxh2x#B)S%FGR_MNZ!t zaFI*n47kYYmhJ-NwB=dkL|E?F(E${sMZ-Foo;Y!Sb1?bG`uK$`3fCMp8q2!=vCU^$ z0`i}70i>GWhmjxN6N%hZpO2hx27~BZ11@rDtN|A}-P&D%oVGlRoCq)2+7GVYdj#Om zx1w#b<*4C{YsE`L{uEsiUe`Gtq)t%fur0tQI-_DY{(b(oE6iBBwit zR)CzgJd2zN%gui^z$cb!*St{<7kInxNlj$nv|F=-Kb?)rgQ$9#GOakX$X7;JBcLme;23$8NjW^)BLAu4e zzzx!tXL}bBmY=A%1AOMu0Y9l#8Y1|O#*J~kM?(h>`uQ;7qb#^7bs1>Zdlm@a zv=5D2kq+kt<$^oc+QF9Pi-k zUf44X)^5}r{L=A6sL#?swj zX+51~+T2Fq+goY)izP`SPv)B%^O0Bi)`tExT`qEIrs;B#)15~vKu%kpMNWj6%eN_&k`rX=?=S*DyKbuq;ZEkGq0kpAD=>oy0JLvMi2NRLWftm zH$kQ=ukg&11F`G$y5bJ&?DhX1nr5nCWtelkaz_Lg>4M5Y8L`XY>6Y;VXpAY0I<7iEzT)c5wGE?mNz92jSes5nxC1aA>Fz1=h{>g@;>DL8lG{K+hZN z@t&k^@VnYI#S(8-EphrTflFK(m%t@Xw~QAMr!CJCC&B?MQ*oQNU!brLEI<=>0jRF^ z#xHto25p=700&1s2G>_5!?s~Z(b&Cv@V(KiL>|^nethD5lMq7RByfpKV-mQ;=~nRq z;rF$rAabgOs)aoX}MaUv{l8DxUzFKq%vInd(##c&Y*raiuX z$s4%ii#!f|Z*UwHkZ-IhX+X(G??=C|a}*sU_isn6Ku5|?JiE|)mnDYXLP zwB=diM7UX!2Tbe|gwIBogP$5j!s{)YKxN#moi(SIpvIS1fhjl#rcYUBX;Cm?=h}%P zzZNdbY*F2W-4JRMcDbOXnXt=5=vD*$T^}q!?J-bh)wJ+CvtlV?%N0_}J_Jb((rA>s4IV_5c4H*Z;TwmrsyD zvid(&Utp8uxPR)tfXN`i>VI72O7~CA7clZ{HzUIGWy!_xHOzW{>9Tn!zFRxAxhM|4 zAL)j3niYd3gFN8u4w3NlGxzaYCk_1RZX2;>Vmm%^)(YA-f~qlK_N#LFAnL||3x@n~ z0=3GWt6YhkwmgfR2)p*ph1%mjVAqf9!Jwz((MZJ+6cv~Xd$`X8U%be~n{vNFiSu6p z-)!eyt8W@a5;wQy6KBgr+b&Wy1_I*LjRBVnAx^Dw=PFkcr!CJCC&G#uC;A#o%KYHS zZ$q)cScDIHMxpVkqtS=~i%_>IZJ|rgMsW1gmN4n`L)7e{t4QRgHhko4nP}Tbs?I=w zoVqjMf+5JMmF`^SO60WVS>!}mHYvwwsbR#|UiiXoWj*kbv**wh*jsY!lX23m4v9{k7 z$Vbk8fZLGMpRdbBF3o&hE^@jPY6Zw?%d^Od@QMiw!9GJFY}@#%F>0R!{Pz*zHq0`==i4bYT;t@u?)>P~Sp zgiWA2V?cFXa9bzhoZ*wk7;uTxt=t8~Y0I<3iLj#UL^)dB?Ey&IH5fmS9fRElH#3e+ z_W?^wZW_P6JOtite;cGstJ$glpNn@ac+)pj9Pun;b;f|oy5IuR(ij6SX}XoWfHZA+ zmbA?n_%}+13#Z88o9SzfdqU*6wz@kUwDT)4y*4rqS-uoqyjuY4>5szg0|H^XyP}gg z(&o_W4(D{&1t-Y4Ki*5-7%&+sPv29k+_}n?UXZpti`-@mdYpZQ{=Cr#KHWYXX$z9! z=&#QL?T)#q<5G7N|NSO3X>Kk)(Q_&)v22B-vXNLX_sTuZ^j+ z;ya^SXQ!jyD<+E@#oXb`^q_s6L54PUHI{|EXs zYKg8)J0{W`WOm`31NOwCjSsbN4Y;JGF$Y}IbgOp(Y1;BEX(9}!7UOBRzW~do?hW;N zVL}hibjOx!YIt~cdvLn7BfQh2FI>651x&g8A%42>PqBq!RCOOAbe9GveTRFhM;dRy zwS#nvcYz(GEzc6Sd4qlf65xRwS5V<+n^AtU9@jph!^4i&$In|%HLjl(0{1Tn!3|_x z;L#HXOO3FfvAC+i6v;=v(=FZw$Z5;7$cga9>)CMP!vxT}$5u2j zE(~nF;|A5H?(p6K3yA3&i!UsH3f7hSpxN>MFlA(hSmK?lB~ITOa0j_G)__Z#ZtX50 zPFtQOPK2MM-RSY%T%1n=MUxRP3IVx+N2Jgw&xL$MgED} z?lQuErldfhFERM;*b-#y^pz!h&?(eBceXfZP-YIV7CHS%x?JSaOw#2dr#p*QfSk5G zi`+Iz_tfYn@R!m*(B9E?@s8vRV0_OMn025SO?cS^?#Vx6DaYZkS6vNET7baio@c}& z533eA{Yko9B{$FP9QC;_bpE$o#q`Nk_V314W z4Y;|;jP z=@#z-;wUC1a0iNMFda1Pl7>tFD6#Y$ zH36Mn)J!b$m}-&J_Xb?#(s% z%d^OdaMy>w0RNjhcw~DUyno~z%cjN=__NhbG0M-xJAYmY#ypsXwtixPlQtBxx}TJsLLfzcP6cXIBj{BI1#QnQvu&>aKew<`oov%$7uidK6tNu z7MeY~9n60dh~^yf!>2wegQ??pqshYx|6lz7aGB0S<4xB8i*rntXUm4u^!-AuGE-u8 z{VbUN7YN*SCZ#b6+;t}DR`G&$CTYvF{fr3Lo@s>5TQ>nC8*nLgOZut{7$D>L`w6X#cqbk_(c9)U|-8jrvwPPd2`5T`BA z5+}lnLAO2zH710^hbB)vX_gF*IoJqq@$|s04kqCBwd|}h(GZCNn3J_U7 znEO_XoW4omBA3P_aFNrk;swZQ%d^OBCP6lGDT*izLcTBZ;py<5sQ$J**fuTJnABu2 zF5R&b?K=Jgnt84bdJ%U5)PA->+-fpq=04RTr*9It$fYp}T;z1CcmZQ9=yMW<-6o)_(ef|ys`fj*m3nNobh!7u_a?~K5~A^ z7);+JaFI)661d3eR`CMlwB=dkM0m`Fm8kXS?s&Bt;HH6>EYI$S;nk;`!ls?gIIz*@ z;JejE@HF!>esX#OymYjS$d7VULbb%{n*=U#X-oo_INd58Lt;Rk zpH{=C<$ut^VJ({$D5S`hWiw z|38@j5yC&SNOxgy($Bc-Oh{u5xa&;Nt=$FdOwg8R`xz0wT+P-IE-D*ej^s<%2jTbhXIq!Obar;(Y%fLf;v1iA&=QxWws} z?gHYpIL6B^jAE7rBBh40VGMccnPf%12B2b&k9!GyBk@S_Vm z@r+ydMSg~u2Jw+s8RYbx0T;P6&VY-YZs{&SPFtQuPK2L3UqVL%{PFI&9=Q0IOt>y_ zExHsw39V0Cwe#x+&G0{)!{M-gjf}m|CE$WpGsJzEl$i(eiB}oq^qm2hxHQgyOPp@$ zE+9@@o+VC%<$uh34Q3h6fClr!@w6}Zqb6?Y@RY6)e7YtIuAbcnJo~*a+%bD4czDlk zmpbx4B9Z40;9G+#gPi&TfL!9zSOYF`y0yE2IBj{BI1#>Cuo)FLxM6HRE)R!j2jji5 z`%#zdIpD9dm8gH*P&nxaGdK~t8i%g*#)r@C6*<`_E1B;Ms*G`JSMqT|OXCc;GWAKugd(uqT^FN0*jvEf0`?ZLw}8C`>@D!_7AP~N^9#i)wx8lj+uRL?dOruV8fU<(E;o^3`nFKn ziW;DJ@&>fO+i{U&+?>iMUgg{#{Z&4=#HE>}%Oy^C7Oj9dZF!bB5tf&q-fQX6Y3R=D vhsMG+TLy!3^G3nzzHY|hQ%2(OhH1b&W}Bty&mDF?T