Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-11-04 18:02:14 +08:00
Support limiting num sequences per category.

Summary: Adds stratified sampling of sequences within categories, applied after the category / sequence filters but before the num-sequences limit. It respects the insertion order into the sequence_annots table, i.e. it takes the top N sequences within each category.

Reviewed By: bottler

Differential Revision: D46724002

fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
This commit is contained in:
parent 5ffeb4d580
commit 09a99f2e6d
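The summary above describes a per-category limit implemented with the SQL ROW_NUMBER() window function, partitioned by category and ordered by insertion order (ROWID). Below is a minimal, self-contained sketch of that pattern against an in-memory SQLite table; the table layout, column names, and sample data are illustrative assumptions, not the project's actual sequence_annots schema or ORM model.

    # Standalone sketch of the ROW_NUMBER-per-category pattern (assumed schema,
    # not the real sequence_annots table used by SqlIndexDataset).
    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    metadata = sa.MetaData()
    sequence_annots = sa.Table(
        "sequence_annots",
        metadata,
        sa.Column("sequence_name", sa.String, primary_key=True),
        sa.Column("category", sa.String),
    )
    metadata.create_all(engine)

    rows = [  # insertion order matters: ROWID follows it
        {"sequence_name": "apple_0", "category": "apple"},
        {"sequence_name": "apple_1", "category": "apple"},
        {"sequence_name": "apple_2", "category": "apple"},
        {"sequence_name": "banana_0", "category": "banana"},
        {"sequence_name": "banana_1", "category": "banana"},
    ]
    with engine.begin() as conn:
        conn.execute(sa.insert(sequence_annots), rows)

    N = 2  # keep at most N sequences per category

    # Number rows within each category in insertion (ROWID) order,
    # then keep only rows whose per-category rank is <= N.
    numbered = sa.select(
        sequence_annots.c.sequence_name,
        sa.func.row_number()
        .over(
            order_by=sa.text("ROWID"),  # SQLite-specific
            partition_by=sequence_annots.c.category,
        )
        .label("row_number"),
    ).subquery()
    stmt = sa.select(numbered.c.sequence_name).where(numbered.c.row_number <= N)

    with engine.connect() as conn:
        print(sorted(conn.execute(stmt).scalars().all()))
        # expected: ['apple_0', 'apple_1', 'banana_0', 'banana_1']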
					
@@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         pick_categories: Restrict the dataset to the given list of categories.
         pick_sequences: A Sequence of sequence names to restrict the dataset to.
         exclude_sequences: A Sequence of the names of the sequences to exclude.
+        limit_sequences_per_category_to: Limit the dataset to the first up to N
+            sequences within each category (applies after all other sequence filters
+            but before `limit_sequences_to`).
         limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
             sequences (after other sequence filters have been applied but before
             frame-based filters).
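As a usage illustration of how the documented filters compose (a hedged sketch: the metadata path, category and sequence names below are placeholders, and additional arguments such as the frame-data builder configuration may be required in practice):

    # Hypothetical configuration; paths and names are placeholders.
    dataset = SqlIndexDataset(
        sqlite_metadata_file="metadata.sqlite",   # placeholder path
        pick_categories=("apple", "banana"),      # category filter
        exclude_sequences=("apple_bad_seq",),     # sequence excludes
        limit_sequences_per_category_to=5,        # first up to 5 per category
        limit_sequences_to=8,                     # then cap the total number of sequences
    )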
@@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 
     pick_sequences: Tuple[str, ...] = ()
     exclude_sequences: Tuple[str, ...] = ()
+    limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
     limit_to: int = 0
     n_frames_per_sequence: int = -1
@@ -373,6 +377,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             self.remove_empty_masks
             or self.limit_to > 0
             or self.limit_sequences_to > 0
+            or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
@@ -380,20 +385,38 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         )
 
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
-        # maximum possible query: WHERE category IN 'self.pick_categories'
+        # maximum possible filter (if limit_sequences_per_category_to == 0):
+        # WHERE category IN 'self.pick_categories'
         # AND sequence_name IN 'self.pick_sequences'
         # AND sequence_name NOT IN 'self.exclude_sequences'
         # LIMIT 'self.limit_sequence_to'
 
-        stmt = sa.select(SqlSequenceAnnotation.sequence_name)
-
         where_conditions = [
             *self._get_category_filters(),
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
-        if where_conditions:
-            stmt = stmt.where(*where_conditions)
+
+        def add_where(stmt):
+            return stmt.where(*where_conditions) if where_conditions else stmt
+
+        if self.limit_sequences_per_category_to <= 0:
+            stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
+        else:
+            subquery = sa.select(
+                SqlSequenceAnnotation.sequence_name,
+                sa.func.row_number()
+                .over(
+                    order_by=sa.text("ROWID"),  # NOTE: ROWID is SQLite-specific
+                    partition_by=SqlSequenceAnnotation.category,
+                )
+                .label("row_number"),
+            )
+
+            subquery = add_where(subquery).subquery()
+            stmt = sa.select(subquery.c.sequence_name).where(
+                subquery.c.row_number <= self.limit_sequences_per_category_to
+            )
 
         if self.limit_sequences_to > 0:
             logger.info(
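For reference, the "maximum possible filter" described in the comment above corresponds roughly to the following SQL shape when the per-category limit is disabled. This is a sketch, not the literal compiled statement: the values are placeholders standing in for pick_categories / pick_sequences / exclude_sequences / limit_sequences_to, and exact identifiers depend on the ORM model.

    # Approximate SQL of the simple branch (limit_sequences_per_category_to == 0);
    # literal values are placeholders for the configured filters.
    SIMPLE_BRANCH_SQL = """
    SELECT sequence_name
    FROM sequence_annots
    WHERE category IN ('apple', 'banana')
      AND sequence_name IN ('apple_0', 'banana_0', 'banana_1')
      AND sequence_name NOT IN ('banana_1')
    ORDER BY ROWID   -- SQLite-specific; applied only when limit_sequences_to > 0
    LIMIT 8
    """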
@@ -402,7 +425,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             # NOTE: ROWID is SQLite-specific
             stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
 
-        if not where_conditions and self.limit_sequences_to <= 0:
+        if (
+            not where_conditions
+            and self.limit_sequences_to <= 0
+            and self.limit_sequences_per_category_to <= 0
+        ):
             # we will not need to filter by sequences
             return None
 
@@ -222,6 +222,30 @@ class TestSqlDataset(unittest.TestCase):
         )
         self.assertEqual(len(dataset), 100)
 
+    def test_limit_sequence_per_category(self, num_sequences=2):
+        dataset = SqlIndexDataset(
+            sqlite_metadata_file=METADATA_FILE,
+            remove_empty_masks=False,
+            limit_sequences_per_category_to=num_sequences,
+            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
+        )
+
+        self.assertEqual(len(dataset), num_sequences * 10 * 2)
+        seq_names = list(dataset.sequence_names())
+        self.assertEqual(len(seq_names), num_sequences * 2)
+        # check that we respect the row order
+        for seq_name in seq_names:
+            self.assertLess(int(seq_name[-1]), num_sequences)
+
+        # test when the limit is not binding
+        dataset = SqlIndexDataset(
+            sqlite_metadata_file=METADATA_FILE,
+            remove_empty_masks=False,
+            limit_sequences_per_category_to=13,
+            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
+        )
+        self.assertEqual(len(dataset), 100)
+
     def test_filter_medley(self):
         dataset = SqlIndexDataset(
             sqlite_metadata_file=METADATA_FILE,