mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Support limiting num sequences per category.
Summary: Adds stratified sampling of sequences within categories, applied after the category / sequence filters but before the overall sequence-count limit (`limit_sequences_to`). It respects the insertion order into the sequence_annots table, i.e. it takes the first N sequences within each category. Reviewed By: bottler. Differential Revision: D46724002. fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
This commit is contained in:
parent
5ffeb4d580
commit
09a99f2e6d
@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
pick_categories: Restrict the dataset to the given list of categories.
|
pick_categories: Restrict the dataset to the given list of categories.
|
||||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||||
|
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||||
|
sequences within each category (applies after all other sequence filters
|
||||||
|
but before `limit_sequences_to`).
|
||||||
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
|
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
|
||||||
sequences (after other sequence filters have been applied but before
|
sequences (after other sequence filters have been applied but before
|
||||||
frame-based filters).
|
frame-based filters).
|
||||||
@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
|
|
||||||
pick_sequences: Tuple[str, ...] = ()
|
pick_sequences: Tuple[str, ...] = ()
|
||||||
exclude_sequences: Tuple[str, ...] = ()
|
exclude_sequences: Tuple[str, ...] = ()
|
||||||
|
limit_sequences_per_category_to: int = 0
|
||||||
limit_sequences_to: int = 0
|
limit_sequences_to: int = 0
|
||||||
limit_to: int = 0
|
limit_to: int = 0
|
||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
@ -373,6 +377,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
self.remove_empty_masks
|
self.remove_empty_masks
|
||||||
or self.limit_to > 0
|
or self.limit_to > 0
|
||||||
or self.limit_sequences_to > 0
|
or self.limit_sequences_to > 0
|
||||||
|
or self.limit_sequences_per_category_to > 0
|
||||||
or len(self.pick_sequences) > 0
|
or len(self.pick_sequences) > 0
|
||||||
or len(self.exclude_sequences) > 0
|
or len(self.exclude_sequences) > 0
|
||||||
or len(self.pick_categories) > 0
|
or len(self.pick_categories) > 0
|
||||||
@ -380,20 +385,38 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||||
# maximum possible query: WHERE category IN 'self.pick_categories'
|
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||||
|
# WHERE category IN 'self.pick_categories'
|
||||||
# AND sequence_name IN 'self.pick_sequences'
|
# AND sequence_name IN 'self.pick_sequences'
|
||||||
# AND sequence_name NOT IN 'self.exclude_sequences'
|
# AND sequence_name NOT IN 'self.exclude_sequences'
|
||||||
# LIMIT 'self.limit_sequences_to'
|
# LIMIT 'self.limit_sequences_to'
|
||||||
|
|
||||||
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
|
|
||||||
|
|
||||||
where_conditions = [
|
where_conditions = [
|
||||||
*self._get_category_filters(),
|
*self._get_category_filters(),
|
||||||
*self._get_pick_filters(),
|
*self._get_pick_filters(),
|
||||||
*self._get_exclude_filters(),
|
*self._get_exclude_filters(),
|
||||||
]
|
]
|
||||||
if where_conditions:
|
|
||||||
stmt = stmt.where(*where_conditions)
|
def add_where(stmt):
|
||||||
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
|
|
||||||
|
if self.limit_sequences_per_category_to <= 0:
|
||||||
|
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
||||||
|
else:
|
||||||
|
subquery = sa.select(
|
||||||
|
SqlSequenceAnnotation.sequence_name,
|
||||||
|
sa.func.row_number()
|
||||||
|
.over(
|
||||||
|
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||||
|
partition_by=SqlSequenceAnnotation.category,
|
||||||
|
)
|
||||||
|
.label("row_number"),
|
||||||
|
)
|
||||||
|
|
||||||
|
subquery = add_where(subquery).subquery()
|
||||||
|
stmt = sa.select(subquery.c.sequence_name).where(
|
||||||
|
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||||
|
)
|
||||||
|
|
||||||
if self.limit_sequences_to > 0:
|
if self.limit_sequences_to > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -402,7 +425,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
# NOTE: ROWID is SQLite-specific
|
# NOTE: ROWID is SQLite-specific
|
||||||
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
|
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
|
||||||
|
|
||||||
if not where_conditions and self.limit_sequences_to <= 0:
|
if (
|
||||||
|
not where_conditions
|
||||||
|
and self.limit_sequences_to <= 0
|
||||||
|
and self.limit_sequences_per_category_to <= 0
|
||||||
|
):
|
||||||
# we will not need to filter by sequences
|
# we will not need to filter by sequences
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(len(dataset), 100)
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
def test_limit_sequence_per_category(self, num_sequences=2):
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_sequences_per_category_to=num_sequences,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), num_sequences * 10 * 2)
|
||||||
|
seq_names = list(dataset.sequence_names())
|
||||||
|
self.assertEqual(len(seq_names), num_sequences * 2)
|
||||||
|
# check that we respect the row order
|
||||||
|
for seq_name in seq_names:
|
||||||
|
self.assertLess(int(seq_name[-1]), num_sequences)
|
||||||
|
|
||||||
|
# test when the limit is not binding
|
||||||
|
dataset = SqlIndexDataset(
|
||||||
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
remove_empty_masks=False,
|
||||||
|
limit_sequences_per_category_to=13,
|
||||||
|
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
def test_filter_medley(self):
|
def test_filter_medley(self):
|
||||||
dataset = SqlIndexDataset(
|
dataset = SqlIndexDataset(
|
||||||
sqlite_metadata_file=METADATA_FILE,
|
sqlite_metadata_file=METADATA_FILE,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user