mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Support limiting num sequences per category.
Summary: Adds stratified sampling of sequences within categories applied after category / sequence filters but before the num sequence limit. It respects the insertion order into the sequence_annots table, i.e. takes top N sequences within each category. Reviewed By: bottler Differential Revision: D46724002 fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
This commit is contained in:
parent
5ffeb4d580
commit
09a99f2e6d
@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
pick_categories: Restrict the dataset to the given list of categories.
|
||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||
sequences within each category (applies after all other sequence filters
|
||||
but before `limit_sequences_to`).
|
||||
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
|
||||
sequences (after other sequence filters have been applied but before
|
||||
frame-based filters).
|
||||
@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
|
||||
pick_sequences: Tuple[str, ...] = ()
|
||||
exclude_sequences: Tuple[str, ...] = ()
|
||||
limit_sequences_per_category_to: int = 0
|
||||
limit_sequences_to: int = 0
|
||||
limit_to: int = 0
|
||||
n_frames_per_sequence: int = -1
|
||||
@ -373,6 +377,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
self.remove_empty_masks
|
||||
or self.limit_to > 0
|
||||
or self.limit_sequences_to > 0
|
||||
or self.limit_sequences_per_category_to > 0
|
||||
or len(self.pick_sequences) > 0
|
||||
or len(self.exclude_sequences) > 0
|
||||
or len(self.pick_categories) > 0
|
||||
@ -380,20 +385,38 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
)
|
||||
|
||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||
# maximum possible query: WHERE category IN 'self.pick_categories'
|
||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||
# WHERE category IN 'self.pick_categories'
|
||||
# AND sequence_name IN 'self.pick_sequences'
|
||||
# AND sequence_name NOT IN 'self.exclude_sequences'
|
||||
# LIMIT 'self.limit_sequence_to'
|
||||
|
||||
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
|
||||
|
||||
where_conditions = [
|
||||
*self._get_category_filters(),
|
||||
*self._get_pick_filters(),
|
||||
*self._get_exclude_filters(),
|
||||
]
|
||||
if where_conditions:
|
||||
stmt = stmt.where(*where_conditions)
|
||||
|
||||
def add_where(stmt):
|
||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||
|
||||
if self.limit_sequences_per_category_to <= 0:
|
||||
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
||||
else:
|
||||
subquery = sa.select(
|
||||
SqlSequenceAnnotation.sequence_name,
|
||||
sa.func.row_number()
|
||||
.over(
|
||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||
partition_by=SqlSequenceAnnotation.category,
|
||||
)
|
||||
.label("row_number"),
|
||||
)
|
||||
|
||||
subquery = add_where(subquery).subquery()
|
||||
stmt = sa.select(subquery.c.sequence_name).where(
|
||||
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||
)
|
||||
|
||||
if self.limit_sequences_to > 0:
|
||||
logger.info(
|
||||
@ -402,7 +425,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
# NOTE: ROWID is SQLite-specific
|
||||
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
|
||||
|
||||
if not where_conditions and self.limit_sequences_to <= 0:
|
||||
if (
|
||||
not where_conditions
|
||||
and self.limit_sequences_to <= 0
|
||||
and self.limit_sequences_per_category_to <= 0
|
||||
):
|
||||
# we will not need to filter by sequences
|
||||
return None
|
||||
|
||||
|
@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(len(dataset), 100)
|
||||
|
||||
def test_limit_sequence_per_category(self, num_sequences=2):
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
remove_empty_masks=False,
|
||||
limit_sequences_per_category_to=num_sequences,
|
||||
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||
)
|
||||
|
||||
self.assertEqual(len(dataset), num_sequences * 10 * 2)
|
||||
seq_names = list(dataset.sequence_names())
|
||||
self.assertEqual(len(seq_names), num_sequences * 2)
|
||||
# check that we respect the row order
|
||||
for seq_name in seq_names:
|
||||
self.assertLess(int(seq_name[-1]), num_sequences)
|
||||
|
||||
# test when the limit is not binding
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
remove_empty_masks=False,
|
||||
limit_sequences_per_category_to=13,
|
||||
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||
)
|
||||
self.assertEqual(len(dataset), 100)
|
||||
|
||||
def test_filter_medley(self):
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
|
Loading…
x
Reference in New Issue
Block a user