diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py index 2c74e56c..470f5a95 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore pick_categories: Restrict the dataset to the given list of categories. pick_sequences: A Sequence of sequence names to restrict the dataset to. exclude_sequences: A Sequence of the names of the sequences to exclude. + limit_sequences_per_category_to: Keep at most the first N sequences of + each category, in ROWID order; values <= 0 (the default) disable this. + Applied after all other sequence filters but before `limit_sequences_to`. limit_sequences_to: Limit the dataset to the first `limit_sequences_to` sequences (after other sequence filters have been applied but before frame-based filters). @@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore pick_sequences: Tuple[str, ...] = () exclude_sequences: Tuple[str, ...] 
= () + limit_sequences_per_category_to: int = 0 limit_sequences_to: int = 0 limit_to: int = 0 n_frames_per_sequence: int = -1 @@ -373,6 +377,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore self.remove_empty_masks or self.limit_to > 0 or self.limit_sequences_to > 0 + or self.limit_sequences_per_category_to > 0 or len(self.pick_sequences) > 0 or len(self.exclude_sequences) > 0 or len(self.pick_categories) > 0 @@ -380,20 +385,38 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore ) def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]: - # maximum possible query: WHERE category IN 'self.pick_categories' + # maximum possible filter (if limit_sequences_per_category_to <= 0): + # WHERE category IN 'self.pick_categories' # AND sequence_name IN 'self.pick_sequences' # AND sequence_name NOT IN 'self.exclude_sequences' # LIMIT 'self.limit_sequence_to' - stmt = sa.select(SqlSequenceAnnotation.sequence_name) - where_conditions = [ *self._get_category_filters(), *self._get_pick_filters(), *self._get_exclude_filters(), ] - if where_conditions: - stmt = stmt.where(*where_conditions) + + def add_where(stmt): + return stmt.where(*where_conditions) if where_conditions else stmt + + if self.limit_sequences_per_category_to <= 0: + stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name)) + else: + subquery = sa.select( + SqlSequenceAnnotation.sequence_name, + sa.func.row_number() + .over( + order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific + partition_by=SqlSequenceAnnotation.category, + ) + .label("row_number"), + ) + + subquery = add_where(subquery).subquery() # ranks filtered rows only + stmt = sa.select(subquery.c.sequence_name).where( + subquery.c.row_number <= self.limit_sequences_per_category_to + ) if self.limit_sequences_to > 0: logger.info( @@ -402,7 +425,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore # NOTE: ROWID is SQLite-specific stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to) - if 
not where_conditions and self.limit_sequences_to <= 0: + if ( + not where_conditions + and self.limit_sequences_to <= 0 + and self.limit_sequences_per_category_to <= 0 + ): # we will not need to filter by sequences return None diff --git a/tests/implicitron/test_sql_dataset.py b/tests/implicitron/test_sql_dataset.py index fe315a67..f5baf505 100644 --- a/tests/implicitron/test_sql_dataset.py +++ b/tests/implicitron/test_sql_dataset.py @@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase): ) self.assertEqual(len(dataset), 100) + def test_limit_sequence_per_category(self, num_sequences=2): + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_sequences_per_category_to=num_sequences, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), num_sequences * 10 * 2) + seq_names = list(dataset.sequence_names()) + self.assertEqual(len(seq_names), num_sequences * 2) + # check that we respect the row order (presumably names end in an index) + for seq_name in seq_names: + self.assertLess(int(seq_name[-1]), num_sequences) + + # test when the limit is not binding + dataset = SqlIndexDataset( + sqlite_metadata_file=METADATA_FILE, + remove_empty_masks=False, + limit_sequences_per_category_to=13, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + self.assertEqual(len(dataset), 100) + def test_filter_medley(self): dataset = SqlIndexDataset( sqlite_metadata_file=METADATA_FILE,