Support limiting the number of sequences per category.

Summary:
Adds stratified sampling of sequences within categories, applied after the category / sequence filters but before the limit on the number of sequences.
It respects the insertion order into the sequence_annots table, i.e., it takes the first N sequences within each category.

Reviewed By: bottler

Differential Revision: D46724002

fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
This commit is contained in:
Roman Shapovalov
2023-06-14 07:12:02 -07:00
committed by Facebook GitHub Bot
parent 5ffeb4d580
commit 09a99f2e6d
2 changed files with 57 additions and 6 deletions

View File

@@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase):
)
self.assertEqual(len(dataset), 100)
def test_limit_sequence_per_category(self, num_sequences=2):
    """Exercise the `limit_sequences_per_category_to` option of SqlIndexDataset.

    First builds a dataset with the limit set to `num_sequences` and checks
    the resulting dataset length and number of sequence names; then builds
    one with a limit of 13 and checks that the dataset length is 100.
    """
    # Constructor arguments shared by both datasets built below.
    shared_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "remove_empty_masks": False,
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }

    ds = SqlIndexDataset(
        limit_sequences_per_category_to=num_sequences,
        **shared_kwargs,
    )
    # NOTE(review): the constant factors are presumably 10 frames per
    # sequence and 2 categories in the fixture -- confirm against
    # METADATA_FILE.
    self.assertEqual(len(ds), num_sequences * 10 * 2)
    kept_names = list(ds.sequence_names())
    self.assertEqual(len(kept_names), num_sequences * 2)

    # check that we respect the row order: the last character of every kept
    # sequence name, read as an integer, has to be below the limit
    for name in kept_names:
        trailing_index = int(name[-1])
        self.assertLess(trailing_index, num_sequences)

    # test when the limit is not binding
    ds = SqlIndexDataset(
        limit_sequences_per_category_to=13,
        **shared_kwargs,
    )
    self.assertEqual(len(ds), 100)
def test_filter_medley(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,