mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Support limiting num sequences per category.
Summary: Adds stratified sampling of sequences within categories applied after category / sequence filters but before the num sequence limit. It respects the insertion order into the sequence_annots table, i.e. takes top N sequences within each category. Reviewed By: bottler Differential Revision: D46724002 fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5ffeb4d580
commit
09a99f2e6d
@@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(len(dataset), 100)
|
||||
|
||||
def test_limit_sequence_per_category(self, num_sequences=2):
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
remove_empty_masks=False,
|
||||
limit_sequences_per_category_to=num_sequences,
|
||||
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||
)
|
||||
|
||||
self.assertEqual(len(dataset), num_sequences * 10 * 2)
|
||||
seq_names = list(dataset.sequence_names())
|
||||
self.assertEqual(len(seq_names), num_sequences * 2)
|
||||
# check that we respect the row order
|
||||
for seq_name in seq_names:
|
||||
self.assertLess(int(seq_name[-1]), num_sequences)
|
||||
|
||||
# test when the limit is not binding
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
remove_empty_masks=False,
|
||||
limit_sequences_per_category_to=13,
|
||||
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||
)
|
||||
self.assertEqual(len(dataset), 100)
|
||||
|
||||
def test_filter_medley(self):
|
||||
dataset = SqlIndexDataset(
|
||||
sqlite_metadata_file=METADATA_FILE,
|
||||
|
||||
Reference in New Issue
Block a user