diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index e8e88b70..88455319 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -450,6 +450,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC): self, frame_annotation: types.FrameAnnotation, sequence_annotation: types.SequenceAnnotation, + load_blobs: bool = True, ) -> FrameDataSubtype: """An abstract method to build the frame data based on raw frame/sequence annotations, load the binary data and adjust them according to the metadata. @@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): Beware that modifications of frame data are done in-place. Args: - dataset_root: The root folder of the dataset; all the paths in jsons are - specified relative to this root (but not json paths themselves). + dataset_root: The root folder of the dataset; all paths in frame / sequence + annotations are defined w.r.t. this root. Has to be set if any of the + load_* flabs below is true. load_images: Enable loading the frame RGB data. load_depths: Enable loading the frame depth maps. load_depth_masks: Enable loading the frame depth map masks denoting the @@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): path_manager: Optionally a PathManager for interpreting paths in a special way. """ - dataset_root: str = "" + dataset_root: Optional[str] = None load_images: bool = True load_depths: bool = True load_depth_masks: bool = True @@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): box_crop_context: float = 0.3 path_manager: Any = None + def __post_init__(self) -> None: + load_any_blob = ( + self.load_images + or self.load_depths + or self.load_depth_masks + or self.load_masks + or self.load_point_clouds + ) + if load_any_blob and self.dataset_root is None: + raise ValueError( + "dataset_root must be set to load any blob data. " + "Make sure it is set in either FrameDataBuilder or Dataset params." + ) + + if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore + raise ValueError( + f"dataset_root is passed but {self.dataset_root} does not exist." + ) + def build( self, frame_annotation: types.FrameAnnotation, @@ -567,7 +588,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): if bbox_xywh is None and fg_mask_np is not None: bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) - frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long) + frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float) if frame_annotation.image is not None: image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) @@ -612,7 +633,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def _load_fg_probability( self, entry: types.FrameAnnotation ) -> Tuple[np.ndarray, str]: - full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore + assert self.dataset_root is not None and entry.mask is not None + full_path = os.path.join(self.dataset_root, entry.mask.path) fg_probability = load_mask(self._local_path(full_path)) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( @@ -647,7 +669,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): fg_probability: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, str, torch.Tensor]: entry_depth = entry.depth - assert entry_depth is not None + assert self.dataset_root is not None and entry_depth is not None path = os.path.join(self.dataset_root, entry_depth.path) depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment) @@ -657,6 +679,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): if self.load_depth_masks: assert entry_depth.mask_path is not None + # pyre-ignore mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) depth_mask = load_depth_mask(self._local_path(mask_path)) else: @@ -705,6 +728,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): ) if path.startswith(unwanted_prefix): path = path[len(unwanted_prefix) :] + assert self.dataset_root is not None return os.path.join(self.dataset_root, path) def _local_path(self, path: str) -> str: diff --git a/pytorch3d/implicitron/dataset/orm_types.py b/pytorch3d/implicitron/dataset/orm_types.py new file mode 100644 index 00000000..5736ab4b --- /dev/null +++ b/pytorch3d/implicitron/dataset/orm_types.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This functionality requires SQLAlchemy 2.0 or later. + +import math +import struct +from typing import Optional, Tuple + +import numpy as np + +from pytorch3d.implicitron.dataset.types import ( + DepthAnnotation, + ImageAnnotation, + MaskAnnotation, + PointCloudAnnotation, + VideoAnnotation, + ViewpointAnnotation, +) + +from sqlalchemy import LargeBinary +from sqlalchemy.orm import ( + composite, + DeclarativeBase, + Mapped, + mapped_column, + MappedAsDataclass, +) +from sqlalchemy.types import TypeDecorator + + +# these produce policies to serialize structured types to blobs +def ArrayTypeFactory(shape): + class NumpyArrayType(TypeDecorator): + impl = LargeBinary + + def process_bind_param(self, value, dialect): + if value is not None: + if value.shape != shape: + raise ValueError(f"Passed an array of wrong shape: {value.shape}") + return value.astype(np.float32).tobytes() + return None + + def process_result_value(self, value, dialect): + if value is not None: + return np.frombuffer(value, dtype=np.float32).reshape(shape) + return None + + return NumpyArrayType + + +def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)): + format_symbol = { + float: "f", # float32 + int: "i", # int32 + }[dtype] + + class TupleType(TypeDecorator): + impl = LargeBinary + _format = format_symbol * math.prod(shape) + + def process_bind_param(self, value, _): + if value is None: + return None + + if len(shape) > 1: + value = np.array(value, dtype=dtype).reshape(-1) + + return struct.pack(TupleType._format, *value) + + def process_result_value(self, value, _): + if value is None: + return None + + loaded = struct.unpack(TupleType._format, value) + if len(shape) > 1: + loaded = _rec_totuple( + np.array(loaded, dtype=dtype).reshape(shape).tolist() + ) + + return loaded + + return TupleType + + +def _rec_totuple(t): + if isinstance(t, list): + return tuple(_rec_totuple(x) for x in t) + + return t + + +class Base(MappedAsDataclass, DeclarativeBase): + """subclasses will be converted to dataclasses""" + + +class SqlFrameAnnotation(Base): + __tablename__ = "frame_annots" + + sequence_name: Mapped[str] = mapped_column(primary_key=True) + frame_number: Mapped[int] = mapped_column(primary_key=True) + frame_timestamp: Mapped[float] = mapped_column(index=True) + + image: Mapped[ImageAnnotation] = composite( + mapped_column("_image_path"), + mapped_column("_image_size", TupleTypeFactory(int)), + ) + + depth: Mapped[DepthAnnotation] = composite( + mapped_column("_depth_path", nullable=True), + mapped_column("_depth_scale_adjustment", nullable=True), + mapped_column("_depth_mask_path", nullable=True), + ) + + mask: Mapped[MaskAnnotation] = composite( + mapped_column("_mask_path", nullable=True), + mapped_column("_mask_mass", index=True, nullable=True), + mapped_column( + "_mask_bounding_box_xywh", + TupleTypeFactory(float, shape=(4,)), + nullable=True, + ), + ) + + viewpoint: Mapped[ViewpointAnnotation] = composite( + mapped_column( + "_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True + ), + mapped_column( + "_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True + ), + mapped_column( + "_viewpoint_focal_length", TupleTypeFactory(float), nullable=True + ), + mapped_column( + "_viewpoint_principal_point", TupleTypeFactory(float), nullable=True + ), + mapped_column("_viewpoint_intrinsics_format", nullable=True), + ) + + +class SqlSequenceAnnotation(Base): + __tablename__ = "sequence_annots" + + sequence_name: Mapped[str] = mapped_column(primary_key=True) + category: Mapped[str] = mapped_column(index=True) + + video: Mapped[VideoAnnotation] = composite( + mapped_column("_video_path", nullable=True), + mapped_column("_video_length", nullable=True), + ) + point_cloud: Mapped[PointCloudAnnotation] = composite( + mapped_column("_point_cloud_path", nullable=True), + mapped_column("_point_cloud_quality_score", nullable=True), + mapped_column("_point_cloud_n_points", nullable=True), + ) + # the bigger the better + viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None) diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py new file mode 100644 index 00000000..605e8b52 --- /dev/null +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -0,0 +1,735 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib +import json +import logging +import os +from dataclasses import dataclass +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import numpy as np +import pandas as pd +import sqlalchemy as sa +import torch +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase + +from pytorch3d.implicitron.dataset.frame_data import ( # noqa + FrameData, + FrameDataBuilder, + FrameDataBuilderBase, +) +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from sqlalchemy.orm import Session + +from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation + + +logger = logging.getLogger(__name__) + + +_SET_LISTS_TABLE: str = "set_lists" + + +@registry.register +class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore + """ + A dataset with annotations stored as SQLite tables. This is an index-based dataset. + The length is returned after all sequence and frame filters are applied (see param + definitions below). Indices can either be ordinal in [0, len), or pairs of + (sequence_name, frame_number); with the performance of `dataset[i]` and + `dataset[sequence_name, frame_number]` being same. A faster way to get metadata only + (without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False. + With ordinal indexing, the sequences are NOT guaranteed to span contiguous index + ranges, and frame numbers are NOT guaranteed to be increasing within a sequence. + Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order` + iterators, which are efficient. + + This functionality requires SQLAlchemy 2.0 or later. + + Metadata-related args: + sqlite_metadata_file: A SQLite file containing frame and sequence annotation + tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation, + respectively). + dataset_root: A root directory to look for images, masks, etc. It can be + alternatively set in `frame_data_builder` args, but this takes precedence. + subset_lists_file: A JSON/sqlite file containing the lists of frames + corresponding to different subsets (e.g. train/val/test) of the dataset; + format: {subset: [(sequence_name, frame_id, file_path)]}. All entries + must be present in frame_annotation metadata table. + path_manager: a facade for non-POSIX filesystems. + subsets: Restrict frames/sequences only to the given list of subsets + as defined in subset_lists_file (see above). Applied before all other + filters. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask (needs frame_annotation.mask.mass to be set; + null values are retained). + pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations + NOTE: This is a potential security risk! The string is passed to the SQL + engine verbatim. Don’t expose it to end users of your application! + pick_categories: Restrict the dataset to the given list of categories. + pick_sequences: A Sequence of sequence names to restrict the dataset to. + exclude_sequences: A Sequence of the names of the sequences to exclude. + limit_sequences_to: Limit the dataset to the first `limit_sequences_to` + sequences (after other sequence filters have been applied but before + frame-based filters). + limit_to: Limit the dataset to the first #limit_to frames (after other + filters have been applied, except n_frames_per_sequence). + n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence` + frames in each sequences uniformly without replacement if it has + more frames than that; applied after other frame-level filters. + seed: The seed of the random generator sampling `n_frames_per_sequence` + random frames per sequence. + """ + + frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation + + sqlite_metadata_file: str = "" + dataset_root: Optional[str] = None + subset_lists_file: str = "" + eval_batches_file: Optional[str] = None + path_manager: Any = None + subsets: Optional[List[str]] = None + remove_empty_masks: bool = True + pick_frames_sql_clause: Optional[str] = None + pick_categories: Tuple[str, ...] = () + + pick_sequences: Tuple[str, ...] = () + exclude_sequences: Tuple[str, ...] = () + limit_sequences_to: int = 0 + limit_to: int = 0 + n_frames_per_sequence: int = -1 + seed: int = 0 + remove_empty_masks_poll_whole_table_threshold: int = 300_000 + # we set it manually in the constructor + # _index: pd.DataFrame = field(init=False) + + frame_data_builder: FrameDataBuilderBase + frame_data_builder_class_type: str = "FrameDataBuilder" + + def __post_init__(self) -> None: + if sa.__version__ < "2.0": + raise ImportError("This class requires SQL Alchemy 2.0 or later") + + if not self.sqlite_metadata_file: + raise ValueError("sqlite_metadata_file must be set") + + if self.dataset_root: + frame_builder_type = self.frame_data_builder_class_type + getattr(self, f"frame_data_builder_{frame_builder_type}_args")[ + "dataset_root" + ] = self.dataset_root + + run_auto_creation(self) + + # pyre-ignore + self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}") + + sequences = self._get_filtered_sequences_if_any() + + if self.subsets: + index = self._build_index_from_subset_lists(sequences) + else: + # TODO: if self.subset_lists_file and not self.subsets, it might be faster to + # still use the concatenated lists, assuming they cover the whole dataset + index = self._build_index_from_db(sequences) + + if self.n_frames_per_sequence >= 0: + index = self._stratified_sample_index(index) + + if len(index) == 0: + raise ValueError(f"There are no frames in the subsets: {self.subsets}!") + + self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore + + self.eval_batches = None # pyre-ignore + if self.eval_batches_file: + self.eval_batches = self._load_filter_eval_batches() + + logger.info(str(self)) + + def __len__(self) -> int: + # pyre-ignore[16] + return len(self._index) + + def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData: + """ + Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair + """ + return self._get_item(frame_idx, True) + + @property + def meta(self): + """ + Allows accessing metadata only without loading blobs using `dataset.meta[idx]`. + Requires box_crop==False, since in that case, cameras cannot be adjusted + without loading masks. + + Returns: + FrameData objects with blob fields like `image_rgb` set to None. + + Raises: + ValueError if dataset.box_crop is set. + """ + return SqlIndexDataset._MetadataAccessor(self) + + @dataclass + class _MetadataAccessor: + dataset: "SqlIndexDataset" + + def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData: + return self.dataset._get_item(frame_idx, False) + + def _get_item( + self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True + ) -> FrameData: + if isinstance(frame_idx, int): + if frame_idx >= len(self._index): + raise IndexError(f"index {frame_idx} out of range {len(self._index)}") + + seq, frame = self._index.index[frame_idx] + else: + seq, frame, *rest = frame_idx + if (seq, frame) not in self._index.index: + raise IndexError( + f"Sequence-frame index {frame_idx} not found; was it filtered out?" + ) + + if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]: + raise IndexError(f"Non-matching image path in {frame_idx}.") + + stmt = sa.select(self.frame_annotations_type).where( + self.frame_annotations_type.sequence_name == seq, + self.frame_annotations_type.frame_number + == int(frame), # cast from np.int64 + ) + seq_stmt = sa.select(SqlSequenceAnnotation).where( + SqlSequenceAnnotation.sequence_name == seq + ) + with Session(self._sql_engine) as session: + entry = session.scalars(stmt).one() + seq_metadata = session.scalars(seq_stmt).one() + + assert entry.image.path == self._index.loc[(seq, frame), "_image_path"] + + frame_data = self.frame_data_builder.build( + entry, seq_metadata, load_blobs=load_blobs + ) + + # The rest of the fields are optional + frame_data.frame_type = self._get_frame_type(entry) + return frame_data + + def __str__(self) -> str: + # pyre-ignore[16] + return f"SqlIndexDataset #frames={len(self._index)}" + + def sequence_names(self) -> Iterable[str]: + """Returns an iterator over sequence names in the dataset.""" + return self._index.index.unique("sequence_name") + + # override + def category_to_sequence_names(self) -> Dict[str, List[str]]: + stmt = sa.select( + SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name + ).where( # we limit results to sequences that have frames after all filters + SqlSequenceAnnotation.sequence_name.in_(self.sequence_names()) + ) + with self._sql_engine.connect() as connection: + cat_to_seqs = pd.read_sql(stmt, connection) + + return cat_to_seqs.groupby("category")["sequence_name"].apply(list).to_dict() + + # override + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None + ) -> List[Tuple[int, float]]: + """ + Implements the DatasetBase method. + + NOTE: Avoid this function as there are more efficient alternatives such as + querying `dataset[idx]` directly or getting all sequence frames with + `sequence_[frames|indices]_in_order`. + + Return the index and timestamp in their videos of the frames whose + indices are given in `idxs`. They need to belong to the same sequence! + If timestamps are absent, they are replaced with zeros. + This is used for letting SceneBatchSampler identify consecutive + frames. + + Args: + idxs: a sequence int frame index in the dataset (it can be a slice) + subset_filter: must remain None + + Returns: + list of tuples of + - frame index in video + - timestamp of frame in video, coalesced with 0s + + Raises: + ValueError if idxs belong to more than one sequence. + """ + + if subset_filter is not None: + raise NotImplementedError( + "Subset filters are not supported in SQL Dataset. " + "We encourage creating a dataset per subset." + ) + + index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs) + # alternatively, we can use `.values.tolist()`, which may be faster + # but returns a list of lists + return list(index_slice.itertuples()) + + # override + def sequence_frames_in_order( + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None + ) -> Iterator[Tuple[float, int, int]]: + """ + Overrides the default DatasetBase implementation (we don’t use `_seq_to_idx`). + Returns an iterator over the frame indices in a given sequence. + We attempt to first sort by timestamp (if they are available), + then by frame number. + + Args: + seq_name: the name of the sequence. + subset_filter: subset names to filter to + + Returns: + an iterator over triplets `(timestamp, frame_no, dataset_idx)`, + where `frame_no` is the index within the sequence, and + `dataset_idx` is the index within the dataset. + `None` timestamps are replaced with 0s. + """ + # TODO: implement sort_timestamp_first? (which would matter if the orders + # of frame numbers and timestamps are different) + rows = self._index.index.get_loc(seq_name) + if isinstance(rows, slice): + assert rows.stop is not None, "Unexpected result from pandas" + rows = range(rows.start or 0, rows.stop, rows.step or 1) + else: + rows = np.where(rows)[0] + + index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices( + rows, seq_name, subset_filter + ) + index_slice["idx"] = idx + + yield from index_slice.itertuples(index=False) + + # override + def get_eval_batches(self) -> Optional[List[Any]]: + """ + This class does not support eval batches with ordinal indices. You can pass + eval_batches as a batch_sampler to a data_loader since the dataset supports + `dataset[seq_name, frame_no]` indexing. + """ + return self.eval_batches + + # override + def join(self, other_datasets: Iterable[DatasetBase]) -> None: + raise ValueError("Not supported! Preprocess the data by merging them instead.") + + # override + @property + def frame_data_type(self) -> Type[FrameData]: + return self.frame_data_builder.frame_data_type + + def is_filtered(self) -> bool: + """ + Returns `True` in case the dataset has been filtered and thus some frame + annotations stored on the disk might be missing in the dataset object. + Does not account for subsets. + + Returns: + is_filtered: `True` if the dataset has been filtered, else `False`. + """ + return ( + self.remove_empty_masks + or self.limit_to > 0 + or self.limit_sequences_to > 0 + or len(self.pick_sequences) > 0 + or len(self.exclude_sequences) > 0 + or len(self.pick_categories) > 0 + or self.n_frames_per_sequence > 0 + ) + + def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]: + # maximum possible query: WHERE category IN 'self.pick_categories' + # AND sequence_name IN 'self.pick_sequences' + # AND sequence_name NOT IN 'self.exclude_sequences' + # LIMIT 'self.limit_sequence_to' + + stmt = sa.select(SqlSequenceAnnotation.sequence_name) + + where_conditions = [ + *self._get_category_filters(), + *self._get_pick_filters(), + *self._get_exclude_filters(), + ] + if where_conditions: + stmt = stmt.where(*where_conditions) + + if self.limit_sequences_to > 0: + logger.info( + f"Limiting dataset to first {self.limit_sequences_to} sequences" + ) + # NOTE: ROWID is SQLite-specific + stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to) + + if not where_conditions and self.limit_sequences_to <= 0: + # we will not need to filter by sequences + return None + + with self._sql_engine.connect() as connection: + sequences = pd.read_sql_query(stmt, connection)["sequence_name"] + logger.info("... retained %d sequences" % len(sequences)) + + return sequences + + def _get_category_filters(self) -> List[sa.ColumnElement]: + if not self.pick_categories: + return [] + + logger.info(f"Limiting dataset to categories: {self.pick_categories}") + return [SqlSequenceAnnotation.category.in_(self.pick_categories)] + + def _get_pick_filters(self) -> List[sa.ColumnElement]: + if not self.pick_sequences: + return [] + + logger.info(f"Limiting dataset to sequences: {self.pick_sequences}") + return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)] + + def _get_exclude_filters(self) -> List[sa.ColumnOperators]: + if not self.exclude_sequences: + return [] + + logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}") + return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)] + + def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame: + assert self.subsets is not None + with open(subset_lists_path, "r") as f: + subset_to_seq_frame = json.load(f) + + seq_frame_list = sum( + ( + [(*row, subset) for row in subset_to_seq_frame[subset]] + for subset in self.subsets + ), + [], + ) + index = pd.DataFrame( + seq_frame_list, + columns=["sequence_name", "frame_number", "_image_path", "subset"], + ) + return index + + def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame: + subsets = self.subsets + assert subsets is not None + # we need a new engine since we store the subsets in a separate DB + engine = sa.create_engine(f"sqlite:///{subset_lists_path}") + table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine) + stmt = sa.select(table).where(table.c.subset.in_(subsets)) + with engine.connect() as connection: + index = pd.read_sql(stmt, connection) + + return index + + def _build_index_from_subset_lists( + self, sequences: Optional[pd.Series] + ) -> pd.DataFrame: + if not self.subset_lists_file: + raise ValueError("Requested subsets but subset_lists_file not given") + + logger.info(f"Loading subset lists from {self.subset_lists_file}.") + + subset_lists_path = self._local_path(self.subset_lists_file) + if subset_lists_path.lower().endswith(".json"): + index = self._load_subsets_from_json(subset_lists_path) + else: + index = self._load_subsets_from_sql(subset_lists_path) + index = index.set_index(["sequence_name", "frame_number"]) + logger.info(f" -> loaded {len(index)} samples of {self.subsets}.") + + if sequences is not None: + logger.info("Applying filtered sequences.") + sequence_values = index.index.get_level_values("sequence_name") + index = index.loc[sequence_values.isin(sequences)] + logger.info(f" -> retained {len(index)} samples.") + + pick_frames_criteria = [] + if self.remove_empty_masks: + logger.info("Culling samples with empty masks.") + + if len(index) > self.remove_empty_masks_poll_whole_table_threshold: + # APPROACH 1: find empty masks and drop indices. + # dev load: 17s / 15 s (3.1M / 500K) + stmt = sa.select( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + ).where(self.frame_annotations_type._mask_mass == 0) + with Session(self._sql_engine) as session: + to_remove = session.execute(stmt).all() + + # Pandas uses np.int64 for integer types, so we have to case + # we might want to read it to pandas DataFrame directly to avoid the loop + to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove] + index.drop(to_remove, errors="ignore", inplace=True) + else: + # APPROACH 3: load index into a temp table and join with annotations + # dev load: 94 s / 23 s (3.1M / 500K) + pick_frames_criteria.append( + sa.or_( + self.frame_annotations_type._mask_mass.is_(None), + self.frame_annotations_type._mask_mass != 0, + ) + ) + + if self.pick_frames_sql_clause: + logger.info("Applying the custom SQL clause.") + pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause)) + + if pick_frames_criteria: + index = self._pick_frames_by_criteria(index, pick_frames_criteria) + + logger.info(f" -> retained {len(index)} samples.") + + if self.limit_to > 0: + logger.info(f"Limiting dataset to first {self.limit_to} frames") + index = index.sort_index().iloc[: self.limit_to] + + return index.reset_index() + + def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame: + IndexTable = self._get_temp_index_table_instance() + with self._sql_engine.connect() as connection: + IndexTable.create(connection) + # we don’t let pandas’s `to_sql` create the table automatically as + # the table would be permanent, so we create it and append with pandas + n_rows = index.to_sql(IndexTable.name, connection, if_exists="append") + assert n_rows == len(index) + sa_type = self.frame_annotations_type + stmt = ( + sa.select(IndexTable) + .select_from( + IndexTable.join( + self.frame_annotations_type, + sa.and_( + sa_type.sequence_name == IndexTable.c.sequence_name, + sa_type.frame_number == IndexTable.c.frame_number, + ), + ) + ) + .where(*criteria) + ) + return pd.read_sql_query(stmt, connection).set_index( + ["sequence_name", "frame_number"] + ) + + def _build_index_from_db(self, sequences: Optional[pd.Series]): + logger.info("Loading sequcence-frame index from the database") + stmt = sa.select( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + self.frame_annotations_type._image_path, + sa.null().label("subset"), + ) + where_conditions = [] + if sequences is not None: + logger.info(" applying filtered sequences") + where_conditions.append( + self.frame_annotations_type.sequence_name.in_(sequences.tolist()) + ) + + if self.remove_empty_masks: + logger.info(" excluding samples with empty masks") + where_conditions.append( + sa.or_( + self.frame_annotations_type._mask_mass.is_(None), + self.frame_annotations_type._mask_mass != 0, + ) + ) + + if self.pick_frames_sql_clause: + logger.info(" applying custom SQL clause") + where_conditions.append(sa.text(self.pick_frames_sql_clause)) + + if where_conditions: + stmt = stmt.where(*where_conditions) + + if self.limit_to > 0: + logger.info(f"Limiting dataset to first {self.limit_to} frames") + stmt = stmt.order_by( + self.frame_annotations_type.sequence_name, + self.frame_annotations_type.frame_number, + ).limit(self.limit_to) + + with self._sql_engine.connect() as connection: + index = pd.read_sql_query(stmt, connection) + + logger.info(f" -> loaded {len(index)} samples.") + return index + + def _sort_index_(self, index): + logger.info("Sorting the index by sequence and frame number.") + index.sort_values(["sequence_name", "frame_number"], inplace=True) + logger.info(" -> Done.") + + def _load_filter_eval_batches(self): + assert self.eval_batches_file + logger.info(f"Loading eval batches from {self.eval_batches_file}") + + if not os.path.isfile(self.eval_batches_file): + # The batch indices file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError( + f"Looking for dataset json file in {self.eval_batches_file}. " + + "Please specify a correct dataset_root folder." + ) + + with open(self.eval_batches_file, "r") as f: + eval_batches = json.load(f) + + # limit the dataset to sequences to allow multiple evaluations in one file + pick_sequences = set(self.pick_sequences) + if self.pick_categories: + cat_to_seq = self.category_to_sequence_names() + pick_sequences.update( + seq for cat in self.pick_categories for seq in cat_to_seq[cat] + ) + + if pick_sequences: + old_len = len(eval_batches) + eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences] + logger.warn( + f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}" + ) + + if self.exclude_sequences: + old_len = len(eval_batches) + exclude_sequences = set(self.exclude_sequences) + eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences] + logger.warn( + f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}" + ) + + return eval_batches + + def _stratified_sample_index(self, index): + # NOTE this stratified sampling can be done more efficiently in + # the no-subset case above if it is added to the SQL query. + # We keep this generic implementation since no-subset case is uncommon + index = index.groupby("sequence_name", group_keys=False).apply( + lambda seq_frames: seq_frames.sample( + min(len(seq_frames), self.n_frames_per_sequence), + random_state=( + _seq_name_to_seed(seq_frames.iloc[0]["sequence_name"]) + self.seed + ), + ) + ) + logger.info(f" -> retained {len(index)} samples aster stratified sampling.") + return index + + def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]: + return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"] + + def _get_frame_no_coalesced_ts_by_row_indices( + self, + idxs: Sequence[int], + seq_name: Optional[str] = None, + subset_filter: Union[Sequence[str], str, None] = None, + ) -> Tuple[pd.DataFrame, Sequence[int]]: + """ + Loads timestamps for given index rows belonging to the same sequence. + If seq_name is known, it speeds up the computation. + Raises ValueError if `idxs` do not all belong to a single sequences . + """ + index_slice = self._index.iloc[idxs] + if subset_filter is not None: + if isinstance(subset_filter, str): + subset_filter = [subset_filter] + indicator = index_slice["subset"].isin(subset_filter) + index_slice = index_slice.loc[indicator] + idxs = [i for i, isin in zip(idxs, indicator) if isin] + + frames = index_slice.index.get_level_values("frame_number").tolist() + if seq_name is None: + seq_name_list = index_slice.index.get_level_values("sequence_name").tolist() + seq_name_set = set(seq_name_list) + if len(seq_name_set) > 1: + raise ValueError("Given indices belong to more than one sequence.") + elif len(seq_name_set) == 1: + seq_name = seq_name_list[0] + + coalesced_ts = sa.sql.functions.coalesce( + self.frame_annotations_type.frame_timestamp, 0 + ) + stmt = sa.select( + coalesced_ts.label("frame_timestamp"), + self.frame_annotations_type.frame_number, + ).where( + self.frame_annotations_type.sequence_name == seq_name, + self.frame_annotations_type.frame_number.in_(frames), + ) + + with self._sql_engine.connect() as connection: + frame_no_ts = pd.read_sql_query(stmt, connection) + + if len(frame_no_ts) != len(index_slice): + raise ValueError( + "Not all indices are found in the database; " + "do they belong to more than one sequence?" + ) + + return frame_no_ts, idxs + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + def _get_temp_index_table_instance(self, table_name: str = "__index"): + CachedTable = self.frame_annotations_type.metadata.tables.get(table_name) + if CachedTable is not None: # table definition is not idempotent + return CachedTable + + return sa.Table( + table_name, + self.frame_annotations_type.metadata, + sa.Column("sequence_name", sa.String, primary_key=True), + sa.Column("frame_number", sa.Integer, primary_key=True), + sa.Column("_image_path", sa.String), + sa.Column("subset", sa.String), + prefixes=["TEMP"], # NOTE SQLite specific! + ) + + +def _seq_name_to_seed(seq_name) -> int: + """Generates numbers in [0, 2 ** 28)""" + return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16) + + +def _safe_as_tensor(data, dtype): + return torch.tensor(data, dtype=dtype) if data is not None else None diff --git a/pytorch3d/implicitron/dataset/sql_dataset_provider.py b/pytorch3d/implicitron/dataset/sql_dataset_provider.py new file mode 100644 index 00000000..ab161e8d --- /dev/null +++ b/pytorch3d/implicitron/dataset/sql_dataset_provider.py @@ -0,0 +1,424 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +from typing import List, Optional, Tuple, Type + +import numpy as np + +from omegaconf import DictConfig, OmegaConf + +from pytorch3d.implicitron.dataset.dataset_map_provider import ( + DatasetMap, + DatasetMapProviderBase, + PathManagerFactory, +) +from pytorch3d.implicitron.tools.config import ( + expand_args_fields, + registry, + run_auto_creation, +) + +from .sql_dataset import SqlIndexDataset + + +_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "") + +# _NEED_CONTROL is a list of those elements of SqlIndexDataset which +# are not directly specified for it in the config but come from the +# DatasetMapProvider. +_NEED_CONTROL: Tuple[str, ...] = ( + "path_manager", + "subsets", + "sqlite_metadata_file", + "subset_lists_file", +) + +logger = logging.getLogger(__name__) + + +@registry.register +class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] + """ + Generates the training, validation, and testing dataset objects for + a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base. + + The dataset is organized in the filesystem as follows:: + + self.dataset_root + ├── + │ ├── + │ │ ├── depth_masks + │ │ ├── depths + │ │ ├── images + │ │ ├── masks + │ │ └── pointcloud.ply + │ ├── + │ │ ├── depth_masks + │ │ ├── depths + │ │ ├── images + │ │ ├── masks + │ │ └── pointcloud.ply + │ ├── ... + │ ├── + │ ├── set_lists + │ ├── .json + │ ├── .json + │ ├── ... + │ ├── .json + │ ├── eval_batches + │ │ ├── .json + │ │ ├── .json + │ │ ├── ... + │ │ ├── .json + │ ├── frame_annotations.jgz + │ ├── sequence_annotations.jgz + ├── + ├── ... + ├── + ├── set_lists + ├── .sqlite + ├── .sqlite + ├── ... + ├── .sqlite + ├── eval_batches + │ ├── .json + │ ├── .json + │ ├── ... + │ ├── .json + + The dataset contains sequences named `` that may be partitioned by + directories such as `` e.g. representing categories but they + can also be stored in a flat structure. Each sequence folder contains the list of + sequence images, depth maps, foreground masks, and valid-depth masks + `images`, `depths`, `masks`, and `depth_masks` respectively. Furthermore, + `set_lists/` dirtectories (with partitions or global) store json or sqlite files + `.`, each describing a certain sequence subset. + These subset path conventions are not hard-coded and arbitrary relative path can be + specified by setting `self.subset_lists_path` to the relative path w.r.t. + dataset root. + + Each `.json` file contains the following dictionary:: + + { + "train": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + "val": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + "test": [ + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + ] + + defining the list of frames (identified with their `sequence_name` and + `frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of + SQLite format, `.sqlite` contains a table with the header:: + + | sequence_name | frame_number | image_path | subset | + + Note that `frame_number` can be obtained only from the metadata and + does not necesarrily correspond to the numeric suffix of the corresponding image + file name (e.g. a file `//images/frame00005.jpg` can + have its frame number set to `20`, not 5). + + Each `.json` file contains a list of evaluation examples + in the following form:: + + [ + [ # batch 1 + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + [ # batch 2 + (sequence_name: str, frame_number: int, image_path: str), + ... + ], + ] + + Note that the evaluation examples always come from the `"test"` subset of the dataset. + (test frames can repeat across batches). The batches can contain single element, + which is typical in case of regular radiance field fitting. + + Args: + subset_lists_path: The relative path to the dataset subset definition. + For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json". + By default (None), dataset is not partitioned to subsets (in that case, setting + `ignore_subsets` will speed up construction) + dataset_root: The root folder of the dataset. + metadata_basename: name of the SQL metadata file in dataset_root; + not expected to be changed by users + test_on_train: Construct validation and test datasets from + the training subset; note that in practice, in this + case all subset dataset objects will be same + only_test_set: Load only the test set. Incompatible with `test_on_train`. + ignore_subsets: Don’t filter by subsets in the dataset; note that in this + case all subset datasets will be same + eval_batch_num_training_frames: Add a certain number of training frames to each + eval batch. Useful for evaluating models that require + source views as input (e.g. NeRF-WCE / PixelNeRF). + dataset_args: Specifies additional arguments to the + JsonIndexDataset constructor call. + path_manager_factory: (Optional) An object that generates an instance of + PathManager that can translate provided file paths. + path_manager_factory_class_type: The class type of `path_manager_factory`. + """ + + category: Optional[str] = None + subset_list_name: Optional[str] = None # TODO: docs + # OR + subset_lists_path: Optional[str] = None + eval_batches_path: Optional[str] = None + + dataset_root: str = _CO3D_SQL_DATASET_ROOT + metadata_basename: str = "metadata.sqlite" + + test_on_train: bool = False + only_test_set: bool = False + ignore_subsets: bool = False + train_subsets: Tuple[str, ...] = ("train",) + val_subsets: Tuple[str, ...] = ("val",) + test_subsets: Tuple[str, ...] = ("test",) + + eval_batch_num_training_frames: int = 0 + + # this is a mould that is never constructed, used to build self._dataset_map values + dataset_class_type: str = "SqlIndexDataset" + dataset: SqlIndexDataset + + path_manager_factory: PathManagerFactory + path_manager_factory_class_type: str = "PathManagerFactory" + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + if self.only_test_set and self.test_on_train: + raise ValueError("Cannot have only_test_set and test_on_train") + + if self.ignore_subsets and not self.only_test_set: + self.test_on_train = True # no point in loading same data 3 times + + path_manager = self.path_manager_factory.get() + + sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename) + sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file) + + if not os.path.isfile(sqlite_metadata_file): + # The sqlite_metadata_file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError( + f"Looking for frame annotations in {sqlite_metadata_file}." + + " Please specify a correct dataset_root folder." + + " Note: By default the root folder is taken from the" + + " CO3D_SQL_DATASET_ROOT environment variable." + ) + + if self.subset_lists_path and self.subset_list_name: + raise ValueError( + "subset_lists_path and subset_list_name cannot be both set" + ) + + subset_lists_file = self._get_lists_file("set_lists") + + # setup the common dataset arguments + common_dataset_kwargs = { + **getattr(self, f"dataset_{self.dataset_class_type}_args"), + "sqlite_metadata_file": sqlite_metadata_file, + "dataset_root": self.dataset_root, + "subset_lists_file": subset_lists_file, + "path_manager": path_manager, + } + + if self.category: + logger.info(f"Forcing category filter in the datasets to {self.category}") + common_dataset_kwargs["pick_categories"] = self.category.split(",") + + # get the used dataset type + dataset_type: Type[SqlIndexDataset] = registry.get( + SqlIndexDataset, self.dataset_class_type + ) + expand_args_fields(dataset_type) + + if subset_lists_file is not None and not os.path.isfile(subset_lists_file): + available_subsets = self._get_available_subsets( + OmegaConf.to_object(common_dataset_kwargs["pick_categories"]) + ) + msg = f"Cannot find subset list file {self.subset_lists_path}." + if available_subsets: + msg += f" Some of the available subsets: {str(available_subsets)}." + raise ValueError(msg) + + train_dataset = None + val_dataset = None + if not self.only_test_set: + # load the training set + logger.debug("Constructing train dataset.") + train_dataset = dataset_type( + **common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets) + ) + logger.info(f"Train dataset: {str(train_dataset)}") + + if self.test_on_train: + assert train_dataset is not None + val_dataset = test_dataset = train_dataset + else: + # load the val and test sets + if not self.only_test_set: + # NOTE: this is always loaded in JsonProviderV2 + logger.debug("Extracting val dataset.") + val_dataset = dataset_type( + **common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets) + ) + logger.info(f"Val dataset: {str(val_dataset)}") + + logger.debug("Extracting test dataset.") + eval_batches_file = self._get_lists_file("eval_batches") + del common_dataset_kwargs["eval_batches_file"] + test_dataset = dataset_type( + **common_dataset_kwargs, + subsets=self._get_subsets(self.test_subsets, True), + eval_batches_file=eval_batches_file, + ) + logger.info(f"Test dataset: {str(test_dataset)}") + + if ( + eval_batches_file is not None + and self.eval_batch_num_training_frames > 0 + ): + self._extend_eval_batches(test_dataset) + + self._dataset_map = DatasetMap( + train=train_dataset, val=val_dataset, test=test_dataset + ) + + def _get_subsets(self, subsets, is_eval: bool = False): + if self.ignore_subsets: + return None + + if is_eval and self.eval_batch_num_training_frames > 0: + # we will need to have training frames for extended batches + return list(subsets) + list(self.train_subsets) + + return subsets + + def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None: + rng = np.random.default_rng(seed=0) + eval_batches = test_dataset.get_eval_batches() + if eval_batches is None: + raise ValueError("Eval batches were not loaded!") + + for batch in eval_batches: + sequence = batch[0][0] + seq_frames = list( + test_dataset.sequence_frames_in_order(sequence, self.train_subsets) + ) + idx_to_add = rng.permutation(len(seq_frames))[ + : self.eval_batch_num_training_frames + ] + batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add) + + @classmethod + def dataset_tweak_args(cls, type, args: DictConfig) -> None: + """ + Called by get_default_args. + Certain fields are not exposed on each dataset class + but rather are controlled by this provider class. + """ + for key in _NEED_CONTROL: + del args[key] + + def create_dataset(self): + # No `dataset` member of this class is created. + # The dataset(s) live in `self.get_dataset_map`. + pass + + def get_dataset_map(self) -> DatasetMap: + return self._dataset_map # pyre-ignore [16] + + def _get_available_subsets(self, categories: List[str]): + """ + Get the available subset names for a given category folder (if given) inside + a root dataset folder `dataset_root`. + """ + path_manager = self.path_manager_factory.get() + + subsets: List[str] = [] + for prefix in [""] + categories: + set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists") + if not ( + (path_manager is not None) and path_manager.isdir(set_list_dir) + ) and not os.path.isdir(set_list_dir): + continue + + set_list_files = (os.listdir if path_manager is None else path_manager.ls)( + set_list_dir + ) + subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files) + + return subsets + + def _get_lists_file(self, flavor: str) -> Optional[str]: + if flavor == "eval_batches": + subset_lists_path = self.eval_batches_path + else: + subset_lists_path = self.subset_lists_path + + if not subset_lists_path and not self.subset_list_name: + return None + + category_elem = "" + if self.category and "," not in self.category: + # if multiple categories are given, looking for global set lists + category_elem = self.category + + subset_lists_path = subset_lists_path or ( + os.path.join( + category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}" + ) + ) + + assert subset_lists_path + path_manager = self.path_manager_factory.get() + # try absolute path first + subset_lists_file = _get_local_path_check_extensions( + subset_lists_path, path_manager + ) + if subset_lists_file: + return subset_lists_file + + full_path = os.path.join(self.dataset_root, subset_lists_path) + subset_lists_file = _get_local_path_check_extensions(full_path, path_manager) + + if not subset_lists_file: + raise FileNotFoundError( + f"Subset lists path given but not found: {full_path}" + ) + + return subset_lists_file + + +def _get_local_path_check_extensions( + path, path_manager, extensions=("", ".sqlite", ".json") +) -> Optional[str]: + for ext in extensions: + local = _local_path(path_manager, path + ext) + if os.path.isfile(local): + return local + + return None + + +def _local_path(path_manager, path: str) -> str: + if path_manager is None: + return path + return path_manager.get_local_path(path) diff --git a/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py b/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py new file mode 100644 index 00000000..4640feb2 --- /dev/null +++ b/pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any, Dict, Optional, Tuple + +from pytorch3d.implicitron.dataset.data_loader_map_provider import ( + DataLoaderMap, + SceneBatchSampler, + SequenceDataLoaderMapProvider, +) +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap +from pytorch3d.implicitron.dataset.frame_data import FrameData +from pytorch3d.implicitron.tools.config import registry, run_auto_creation + +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D +# and support both eval_batches protocols +@registry.register +class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider): + """ + Implementation of DataLoaderMapProviderBase that may use internal eval batches for + the test dataset. In particular, if `eval_batches_relpath` is set, it loads + eval batches from that json file, otherwise test set is treated in the same way as + train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type` + are respected. + + If conditioning is not required, then the batch size should + be set as 1, and most of the fields do not matter. + + If conditioning is required, each batch will contain one main + frame first to predict and the, rest of the elements are for + conditioning. + + If images_per_seq_options is left empty, the conditioning + frames are picked according to the conditioning type given. + This does not have regard to the order of frames in a + scene, or which frames belong to what scene. + + If images_per_seq_options is given, then the conditioning types + must be SAME and the remaining fields are used. + + Members: + batch_size: The size of the batch of the data loader. + num_workers: Number of data-loading threads in each data loader. + dataset_length_train: The number of batches in a training epoch. Or 0 to mean + an epoch is the length of the training set. + dataset_length_val: The number of batches in a validation epoch. Or 0 to mean + an epoch is the length of the validation set. + dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of + batches in a testing epoch. Or 0 to mean an epoch is the length of the test + set. + images_per_seq_options: Possible numbers of frames sampled per sequence in a batch. + If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial + value. Empty (the default) means that we are not careful about which frames + come from which scene. + sample_consecutive_frames: if True, will sample a contiguous interval of frames + in the sequence. It first sorts the frames by timestimps when available, + otherwise by frame numbers, finds the connected segments within the sequence + of sufficient length, then samples a random pivot element among them and + ideally uses it as a middle of the temporal window, shifting the borders + where necessary. This strategy mitigates the bias against shorter segments + and their boundaries. + consecutive_frames_max_gap: if a number > 0, then used to define the maximum + difference in frame_number of neighbouring frames when forming connected + segments; if both this and consecutive_frames_max_gap_seconds are 0s, + the whole sequence is considered a segment regardless of frame numbers. + consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the + maximum difference in frame_timestamp of neighbouring frames when forming + connected segments; if both this and consecutive_frames_max_gap are 0s, + the whole sequence is considered a segment regardless of frame timestamps. + """ + + batch_size: int = 1 + num_workers: int = 0 + + dataset_length_train: int = 0 + dataset_length_val: int = 0 + dataset_length_test: int = 0 + + images_per_seq_options: Tuple[int, ...] = () + sample_consecutive_frames: bool = False + consecutive_frames_max_gap: int = 0 + consecutive_frames_max_gap_seconds: float = 0.1 + + def __post_init__(self): + run_auto_creation(self) + + def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap: + """ + Returns a collection of data loaders for a given collection of datasets. + """ + train = self._make_generic_data_loader( + datasets.train, + self.dataset_length_train, + datasets.train, + ) + + val = self._make_generic_data_loader( + datasets.val, + self.dataset_length_val, + datasets.train, + ) + + if datasets.test is not None and datasets.test.get_eval_batches() is not None: + test = self._make_eval_data_loader(datasets.test) + else: + test = self._make_generic_data_loader( + datasets.test, + self.dataset_length_test, + datasets.train, + ) + + return DataLoaderMap(train=train, val=val, test=test) + + def _make_eval_data_loader( + self, + dataset: Optional[DatasetBase], + ) -> Optional[DataLoader[FrameData]]: + if dataset is None: + return None + + return DataLoader( + dataset, + batch_sampler=dataset.get_eval_batches(), + **self._get_data_loader_common_kwargs(dataset), + ) + + def _make_generic_data_loader( + self, + dataset: Optional[DatasetBase], + num_batches: int, + train_dataset: Optional[DatasetBase], + ) -> Optional[DataLoader[FrameData]]: + """ + Returns the dataloader for a dataset. + + Args: + dataset: the dataset + num_batches: possible ceiling on number of batches per epoch + train_dataset: the training dataset, used if conditioning_type==TRAIN + conditioning_type: source for padding of batches + """ + if dataset is None: + return None + + data_loader_kwargs = self._get_data_loader_common_kwargs(dataset) + + if len(self.images_per_seq_options) > 0: + # this is a typical few-view setup + # conditioning comes from the same subset since subsets are split by seqs + batch_sampler = SceneBatchSampler( + dataset, + self.batch_size, + num_batches=len(dataset) if num_batches <= 0 else num_batches, + images_per_seq_options=self.images_per_seq_options, + sample_consecutive_frames=self.sample_consecutive_frames, + consecutive_frames_max_gap=self.consecutive_frames_max_gap, + consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds, + ) + return DataLoader( + dataset, + batch_sampler=batch_sampler, + **data_loader_kwargs, + ) + + if self.batch_size == 1: + # this is a typical many-view setup (without conditioning) + return self._simple_loader(dataset, num_batches, data_loader_kwargs) + + # edge case: conditioning on train subset, typical for Nerformer-like many-view + # there is only one sequence in all datasets, so we condition on another subset + return self._train_loader( + dataset, train_dataset, num_batches, data_loader_kwargs + ) + + def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]: + return { + "num_workers": self.num_workers, + "collate_fn": dataset.frame_data_type.collate, + } diff --git a/setup.py b/setup.py index 357073ab..54e6283f 100755 --- a/setup.py +++ b/setup.py @@ -164,6 +164,7 @@ setup( "tqdm>4.29.0", "matplotlib", "accelerate", + "sqlalchemy>=2.0", ], }, entry_points={ diff --git a/tests/implicitron/data/sql_dataset/set_lists_100.json b/tests/implicitron/data/sql_dataset/set_lists_100.json new file mode 100644 index 00000000..96dbe2b4 --- /dev/null +++ b/tests/implicitron/data/sql_dataset/set_lists_100.json @@ -0,0 +1 @@ +{"train": [["cat0_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000000.jpg"], ["cat0_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000002.jpg"], ["cat0_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000004.jpg"], ["cat0_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000006.jpg"], ["cat0_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000008.jpg"], ["cat0_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000000.jpg"], ["cat0_seq1", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000002.jpg"], ["cat0_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000004.jpg"], ["cat0_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000006.jpg"], ["cat0_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000008.jpg"], ["cat0_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000000.jpg"], ["cat0_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000002.jpg"], ["cat0_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000004.jpg"], ["cat0_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000006.jpg"], ["cat0_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000008.jpg"], ["cat0_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000000.jpg"], ["cat0_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000002.jpg"], ["cat0_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000004.jpg"], ["cat0_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000006.jpg"], ["cat0_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000008.jpg"], ["cat0_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000000.jpg"], ["cat0_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000002.jpg"], ["cat0_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000004.jpg"], ["cat0_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000006.jpg"], ["cat0_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000008.jpg"], ["cat1_seq0", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000000.jpg"], ["cat1_seq0", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000002.jpg"], ["cat1_seq0", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000004.jpg"], ["cat1_seq0", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000006.jpg"], ["cat1_seq0", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000008.jpg"], ["cat1_seq1", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000000.jpg"], ["cat1_seq1", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000002.jpg"], ["cat1_seq1", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000004.jpg"], ["cat1_seq1", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000006.jpg"], ["cat1_seq1", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000008.jpg"], ["cat1_seq2", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000000.jpg"], ["cat1_seq2", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000002.jpg"], ["cat1_seq2", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000004.jpg"], ["cat1_seq2", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000006.jpg"], ["cat1_seq2", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000008.jpg"], ["cat1_seq3", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000000.jpg"], ["cat1_seq3", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000002.jpg"], ["cat1_seq3", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000004.jpg"], ["cat1_seq3", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000006.jpg"], ["cat1_seq3", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000008.jpg"], ["cat1_seq4", 0, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000000.jpg"], ["cat1_seq4", 2, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000002.jpg"], ["cat1_seq4", 4, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000004.jpg"], ["cat1_seq4", 6, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000006.jpg"], ["cat1_seq4", 8, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000008.jpg"]], "test": [["cat0_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000001.jpg"], ["cat0_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000003.jpg"], ["cat0_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000005.jpg"], ["cat0_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000007.jpg"], ["cat0_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq0/frame000009.jpg"], ["cat0_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000001.jpg"], ["cat0_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000003.jpg"], ["cat0_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000005.jpg"], ["cat0_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000007.jpg"], ["cat0_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq1/frame000009.jpg"], ["cat0_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000001.jpg"], ["cat0_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000003.jpg"], ["cat0_seq2", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000005.jpg"], ["cat0_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000007.jpg"], ["cat0_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq2/frame000009.jpg"], ["cat0_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000001.jpg"], ["cat0_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000003.jpg"], ["cat0_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000005.jpg"], ["cat0_seq3", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000007.jpg"], ["cat0_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq3/frame000009.jpg"], ["cat0_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000001.jpg"], ["cat0_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000003.jpg"], ["cat0_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000005.jpg"], ["cat0_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000007.jpg"], ["cat0_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat0_seq4/frame000009.jpg"], ["cat1_seq0", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000001.jpg"], ["cat1_seq0", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000003.jpg"], ["cat1_seq0", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000005.jpg"], ["cat1_seq0", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000007.jpg"], ["cat1_seq0", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq0/frame000009.jpg"], ["cat1_seq1", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000001.jpg"], ["cat1_seq1", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000003.jpg"], ["cat1_seq1", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000005.jpg"], ["cat1_seq1", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000007.jpg"], ["cat1_seq1", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq1/frame000009.jpg"], ["cat1_seq2", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000001.jpg"], ["cat1_seq2", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000003.jpg"], ["cat1_seq2", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000005.jpg"], ["cat1_seq2", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000007.jpg"], ["cat1_seq2", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq2/frame000009.jpg"], ["cat1_seq3", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000001.jpg"], ["cat1_seq3", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000003.jpg"], ["cat1_seq3", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000005.jpg"], ["cat1_seq3", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000007.jpg"], ["cat1_seq3", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq3/frame000009.jpg"], ["cat1_seq4", 1, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000001.jpg"], ["cat1_seq4", 3, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000003.jpg"], ["cat1_seq4", 5, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000005.jpg"], ["cat1_seq4", 7, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000007.jpg"], ["cat1_seq4", 9, "kfcdtsiagiruwsuplqemogkmqyqhfvpwbvdrikpjlnegagzxhwxrguehparmirtk/cat1_seq4/frame000009.jpg"]]} \ No newline at end of file diff --git a/tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite b/tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite new file mode 100644 index 00000000..2d8ea3b2 Binary files /dev/null and b/tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite differ diff --git a/tests/implicitron/test_co3d_sql.py b/tests/implicitron/test_co3d_sql.py new file mode 100644 index 00000000..7f873cfc --- /dev/null +++ b/tests/implicitron/test_co3d_sql.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import unittest + +import torch + +from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa + SequenceDataLoaderMapProvider, + SimpleDataLoaderMapProvider, +) +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource +from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa +from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa + SqlIndexDatasetMapProvider, +) +from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import ( + TrainEvalDataLoaderMapProvider, +) +from pytorch3d.implicitron.tools.config import get_default_args + +logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset") +sh = logging.StreamHandler() +logger.addHandler(sh) +logger.setLevel(logging.DEBUG) + +_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "") + + +@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available") +class TestCo3dSqlDataSource(unittest.TestCase): + def test_no_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.ignore_subsets = True + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.pick_categories = ["skateboard"] + dataset_args.limit_sequences_to = 1 + + data_source = ImplicitronDataSource(**args) + self.assertIsInstance( + data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider + ) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 202) + for frame in data_loaders.train: + self.assertIsNone(frame.frame_type) + self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs + break + + def test_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = ( + "skateboard/set_lists/set_lists_manyview_dev_0.json" + ) + # this will naturally limit to one sequence (no need to limit by cat/sequence) + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + + for sampler_type in [ + "SimpleDataLoaderMapProvider", + "SequenceDataLoaderMapProvider", + "TrainEvalDataLoaderMapProvider", + ]: + args.data_loader_map_provider_class_type = sampler_type + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 102) + self.assertEqual(len(data_loaders.val), 100) + self.assertEqual(len(data_loaders.test), 100) + for split in ["train", "val", "test"]: + for frame in data_loaders[split]: + self.assertEqual(frame.frame_type, [split]) + # check loading blobs + self.assertEqual(frame.image_rgb.shape[-1], 800) + break + + def test_sql_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite" + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + dataset_args.pick_categories = ["skateboard"] + + for sampler_type in [ + "SimpleDataLoaderMapProvider", + "SequenceDataLoaderMapProvider", + "TrainEvalDataLoaderMapProvider", + ]: + args.data_loader_map_provider_class_type = sampler_type + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 102) + self.assertEqual(len(data_loaders.val), 100) + self.assertEqual(len(data_loaders.test), 100) + for split in ["train", "val", "test"]: + for frame in data_loaders[split]: + self.assertEqual(frame.frame_type, [split]) + self.assertEqual( + frame.image_rgb.shape[-1], 800 + ) # check loading blobs + break + + @unittest.skip("It takes 75 seconds; skipping by default") + def test_huge_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite" + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 3158974) + self.assertEqual(len(data_loaders.val), 518417) + self.assertEqual(len(data_loaders.test), 518417) + for split in ["train", "val", "test"]: + for frame in data_loaders[split]: + self.assertEqual(frame.frame_type, [split]) + self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs + break + + def test_broken_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = "et_non_est" + provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"] + with self.assertRaises(FileNotFoundError) as err: + ImplicitronDataSource(**args) + + # check the hint text + self.assertIn("Subset lists path given but not found", str(err.exception)) + + def test_eval_batches(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite" + provider_args.eval_batches_path = ( + "skateboard/eval_batches/eval_batches_manyview_dev_0.json" + ) + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + dataset_args.pick_categories = ["skateboard"] + + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 102) + self.assertEqual(len(data_loaders.val), 100) + self.assertEqual(len(data_loaders.test), 50) + for split in ["train", "val", "test"]: + for frame in data_loaders[split]: + self.assertEqual(frame.frame_type, [split]) + self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs + break + + def test_eval_batches_from_subset_list_name(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_list_name = "manyview_dev_0" + provider_args.category = "skateboard" + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + + data_source = ImplicitronDataSource(**args) + dataset, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"]) + self.assertEqual(len(data_loaders.train), 102) + self.assertEqual(len(data_loaders.val), 100) + self.assertEqual(len(data_loaders.test), 50) + for split in ["train", "val", "test"]: + for frame in data_loaders[split]: + self.assertEqual(frame.frame_type, [split]) + self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs + break + + def test_frame_access(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider" + provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args + provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite" + + dataset_args = provider_args.dataset_SqlIndexDataset_args + dataset_args.remove_empty_masks = True + dataset_args.pick_categories = ["skateboard"] + frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args + frame_builder_args.load_point_clouds = True + frame_builder_args.box_crop = False # required for .meta + + data_source = ImplicitronDataSource(**args) + dataset_map, _ = data_source.get_datasets_and_dataloaders() + dataset = dataset_map["train"] + + for idx in [10, ("245_26182_52130", 22)]: + example_meta = dataset.meta[idx] + example = dataset[idx] + + self.assertIsNone(example_meta.image_rgb) + self.assertIsNone(example_meta.fg_probability) + self.assertIsNone(example_meta.depth_map) + self.assertIsNone(example_meta.sequence_point_cloud) + self.assertIsNotNone(example_meta.camera) + + self.assertIsNotNone(example.image_rgb) + self.assertIsNotNone(example.fg_probability) + self.assertIsNotNone(example.depth_map) + self.assertIsNotNone(example.sequence_point_cloud) + self.assertIsNotNone(example.camera) + + self.assertEqual(example_meta.sequence_name, example.sequence_name) + self.assertEqual(example_meta.frame_number, example.frame_number) + self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp) + self.assertEqual(example_meta.sequence_category, example.sequence_category) + torch.testing.assert_close(example_meta.camera.R, example.camera.R) + torch.testing.assert_close(example_meta.camera.T, example.camera.T) + torch.testing.assert_close( + example_meta.camera.focal_length, example.camera.focal_length + ) + torch.testing.assert_close( + example_meta.camera.principal_point, example.camera.principal_point + ) diff --git a/tests/implicitron/test_sql_dataset.py b/tests/implicitron/test_sql_dataset.py new file mode 100644 index 00000000..fe315a67 --- /dev/null +++ b/tests/implicitron/test_sql_dataset.py @@ -0,0 +1,522 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import unittest +from collections import Counter + +import pkg_resources + +import torch + +from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset + +NO_BLOBS_KWARGS = { + "dataset_root": "", + "load_images": False, + "load_depths": False, + "load_masks": False, + "load_depth_masks": False, + "box_crop": False, +} + +logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset") +sh = logging.StreamHandler() +logger.addHandler(sh) +logger.setLevel(logging.DEBUG) + + +DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset") +METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite") +SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json") + + +class TestSqlDataset(unittest.TestCase): + def test_basic(self, sequence="cat1_seq2", frame_number=4): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 100) + + # check the items are consecutive + past_sequences = set() + last_frame_number = -1 + last_sequence = "" + for i in range(len(dataset)): + item = dataset[i] + + if item.frame_number == 0: + self.assertNotIn(item.sequence_name, past_sequences) + past_sequences.add(item.sequence_name) + last_sequence = item.sequence_name + else: + self.assertEqual(item.sequence_name, last_sequence) + self.assertEqual(item.frame_number, last_frame_number + 1) + + last_frame_number = item.frame_number + + # test indexing + with self.assertRaises(IndexError): + dataset[len(dataset) + 1] + + # test sequence-frame indexing + item = dataset[sequence, frame_number] + self.assertEqual(item.sequence_name, sequence) + self.assertEqual(item.frame_number, frame_number) + + with self.assertRaises(IndexError): + dataset[sequence, 13] + + def test_filter_empty_masks(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 78) + + def test_pick_frames_sql_clause(self): + dataset_no_empty_masks = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0", + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + # check the datasets are equal + self.assertEqual(len(dataset), len(dataset_no_empty_masks)) + for i in range(len(dataset)): + item_nem = dataset_no_empty_masks[i] + item = dataset[i] + self.assertEqual(item_nem.image_path, item.image_path) + + # remove_empty_masks together with the custom criterion + dataset_ts = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + pick_frames_sql_clause="frame_timestamp < 0.15", + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + self.assertEqual(len(dataset_ts), 19) + + def test_limit_categories(self, category="cat0"): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + pick_categories=[category], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 50) + for i in range(len(dataset)): + self.assertEqual(dataset[i].sequence_category, category) + + def test_limit_sequences(self, num_sequences=3): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_sequences_to=num_sequences, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 10 * num_sequences) + + def delist(sequence_name): + return sequence_name if isinstance(sequence_name, str) else sequence_name[0] + + unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))} + self.assertEqual(len(unique_seqs), num_sequences) + + def test_pick_exclude_sequencess(self, sequence="cat1_seq2"): + # pick sequence + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + pick_sequences=[sequence], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 10) + unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))} + self.assertCountEqual(unique_seqs, {sequence}) + + item = dataset[sequence, 0] + self.assertEqual(item.sequence_name, sequence) + self.assertEqual(item.frame_number, 0) + + # exclude sequence + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + exclude_sequences=[sequence], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 90) + unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))} + self.assertNotIn(sequence, unique_seqs) + + with self.assertRaises(IndexError): + dataset[sequence, 0] + + def test_limit_frames(self, num_frames=13): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_to=num_frames, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), num_frames) + unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))} + self.assertEqual(len(unique_seqs), 2) + + # test when the limit is not binding + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_to=1000, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 100) + + def test_limit_frames_per_sequence(self, num_frames=2): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + n_frames_per_sequence=num_frames, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), num_frames * 10) + seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset))) + self.assertEqual(len(seq_counts), 10) + self.assertCountEqual( + set(seq_counts.values()), {2} + ) # all counts are num_frames + + with self.assertRaises(IndexError): + dataset[next(iter(seq_counts)), num_frames + 1] + + # test when the limit is not binding + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + n_frames_per_sequence=13, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + self.assertEqual(len(dataset), 100) + + def test_filter_medley(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + pick_categories=["cat1"], + exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on + limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2" + limit_to=14, # retaining full "cat1_seq1" and 4 from "cat1_seq2" + n_frames_per_sequence=6, # cutting "cat1_seq1" to 6 frames + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + # result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2 + seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset))) + self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"]) + self.assertEqual(seq_counts["cat1_seq1"], 6) + self.assertEqual(seq_counts["cat1_seq2"], 4) + + def test_subsets_trivial(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + subset_lists_file=SET_LIST_FILE, + limit_to=100, # force sorting + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 100) + + # check the items are consecutive + past_sequences = set() + last_frame_number = -1 + last_sequence = "" + for i in range(len(dataset)): + item = dataset[i] + + if item.frame_number == 0: + self.assertNotIn(item.sequence_name, past_sequences) + past_sequences.add(item.sequence_name) + last_sequence = item.sequence_name + else: + self.assertEqual(item.sequence_name, last_sequence) + self.assertEqual(item.frame_number, last_frame_number + 1) + + last_frame_number = item.frame_number + + def test_subsets_filter_empty_masks(self): + # we need to test this case as it uses quite different logic with `df.drop()` + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 78) + + def test_subsets_pick_frames_sql_clause(self): + dataset_no_empty_masks = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0", + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + # check the datasets are equal + self.assertEqual(len(dataset), len(dataset_no_empty_masks)) + for i in range(len(dataset)): + item_nem = dataset_no_empty_masks[i] + item = dataset[i] + self.assertEqual(item_nem.image_path, item.image_path) + + # remove_empty_masks together with the custom criterion + dataset_ts = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + pick_frames_sql_clause="frame_timestamp < 0.15", + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset_ts), 19) + + def test_single_subset(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + subset_lists_file=SET_LIST_FILE, + subsets=["train"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 50) + + with self.assertRaises(IndexError): + dataset[51] + + # check the items are consecutive + past_sequences = set() + last_frame_number = -1 + last_sequence = "" + for i in range(len(dataset)): + item = dataset[i] + + if item.frame_number < 2: + self.assertNotIn(item.sequence_name, past_sequences) + past_sequences.add(item.sequence_name) + last_sequence = item.sequence_name + else: + self.assertEqual(item.sequence_name, last_sequence) + self.assertEqual(item.frame_number, last_frame_number + 2) + + last_frame_number = item.frame_number + + item = dataset[last_sequence, 0] + self.assertEqual(item.sequence_name, last_sequence) + + with self.assertRaises(IndexError): + dataset[last_sequence, 1] + + def test_subset_with_filters(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + subset_lists_file=SET_LIST_FILE, + subsets=["train"], + pick_categories=["cat1"], + exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on + limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2" + limit_to=7, # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2" + n_frames_per_sequence=3, # cutting "cat1_seq1" to 3 frames + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + # result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2 + seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset))) + self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"]) + self.assertEqual(seq_counts["cat1_seq1"], 3) + self.assertEqual(seq_counts["cat1_seq2"], 2) + + def test_visitor(self): + dataset_sorted = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + sequences = dataset_sorted.sequence_names() + i = 0 + for seq in sequences: + last_ts = float("-Inf") + for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq): + self.assertEqual(i, idx) + i += 1 + self.assertGreaterEqual(ts, last_ts) + last_ts = ts + + # test legacy visitor + old_indices = None + for seq in sequences: + last_ts = float("-Inf") + rows = dataset_sorted._index.index.get_loc(seq) + indices = list(range(rows.start or 0, rows.stop, rows.step or 1)) + fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices) + self.assertEqual(len(fn_ts_list), len(indices)) + + if old_indices: + # check raising if we ask for multiple sequences + with self.assertRaises(ValueError): + dataset_sorted.get_frame_numbers_and_timestamps( + indices + old_indices + ) + + old_indices = indices + + def test_visitor_subsets(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_to=100, # force sorting + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + sequences = dataset.sequence_names() + i = 0 + for seq in sequences: + last_ts = float("-Inf") + seq_frames = list(dataset.sequence_frames_in_order(seq)) + self.assertEqual(len(seq_frames), 10) + for ts, _, idx in seq_frames: + self.assertEqual(i, idx) + i += 1 + self.assertGreaterEqual(ts, last_ts) + last_ts = ts + + last_ts = float("-Inf") + train_frames = list(dataset.sequence_frames_in_order(seq, "train")) + self.assertEqual(len(train_frames), 5) + for ts, _, _ in train_frames: + self.assertGreaterEqual(ts, last_ts) + last_ts = ts + + def test_category_to_sequence_names(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + subset_lists_file=SET_LIST_FILE, + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + cat_to_seqs = dataset.category_to_sequence_names() + self.assertEqual(len(cat_to_seqs), 2) + self.assertIn("cat1", cat_to_seqs) + self.assertEqual(len(cat_to_seqs["cat1"]), 5) + + # check that override preserves the behavior + cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names() + self.assertDictEqual(cat_to_seqs, cat_to_seqs_base) + + def test_category_to_sequence_names_filters(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=True, + subset_lists_file=SET_LIST_FILE, + exclude_sequences=["cat1_seq0"], + subsets=["train", "test"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + cat_to_seqs = dataset.category_to_sequence_names() + self.assertEqual(len(cat_to_seqs), 2) + self.assertIn("cat1", cat_to_seqs) + self.assertEqual(len(cat_to_seqs["cat1"]), 4) # minus one + + # check that override preserves the behavior + cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names() + self.assertDictEqual(cat_to_seqs, cat_to_seqs_base) + + def test_meta_access(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + subset_lists_file=SET_LIST_FILE, + subsets=["train"], + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 50) + + for idx in [10, ("cat0_seq2", 2)]: + example_meta = dataset.meta[idx] + example = dataset[idx] + self.assertEqual(example_meta.sequence_name, example.sequence_name) + self.assertEqual(example_meta.frame_number, example.frame_number) + self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp) + self.assertEqual(example_meta.sequence_category, example.sequence_category) + torch.testing.assert_close(example_meta.camera.R, example.camera.R) + torch.testing.assert_close(example_meta.camera.T, example.camera.T) + torch.testing.assert_close( + example_meta.camera.focal_length, example.camera.focal_length + ) + torch.testing.assert_close( + example_meta.camera.principal_point, example.camera.principal_point + ) + + def test_meta_access_no_blobs(self): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + subset_lists_file=SET_LIST_FILE, + subsets=["train"], + frame_data_builder_FrameDataBuilder_args={ + "dataset_root": ".", + "box_crop": False, # required by blob-less accessor + }, + ) + + self.assertIsNone(dataset.meta[0].image_rgb) + self.assertIsNone(dataset.meta[0].fg_probability) + self.assertIsNone(dataset.meta[0].depth_map) + self.assertIsNone(dataset.meta[0].sequence_point_cloud) + self.assertIsNotNone(dataset.meta[0].camera)