mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Generalising SqlIndexDataset to support subtypes of SqlSequenceAnnotation
Summary: We did not often extend sequence-level metadata, but now, for applications like text-to-3D/video, we need to store captions and similar metadata.

Reviewed By: bottler

Differential Revision: D68269926

fbshipit-source-id: f8af308adce51863d719a335d85cd2558943bd4c
This commit is contained in:
		
							parent
							
								
									699bc671ca
								
							
						
					
					
						commit
						42a4a7d432
					
				@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
 | 
			
		||||
    sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
 | 
			
		||||
        SqlSequenceAnnotation
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    sqlite_metadata_file: str = ""
 | 
			
		||||
    dataset_root: Optional[str] = None
 | 
			
		||||
@ -246,8 +249,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            self.frame_annotations_type.frame_number
 | 
			
		||||
            == int(frame),  # cast from np.int64
 | 
			
		||||
        )
 | 
			
		||||
        seq_stmt = sa.select(SqlSequenceAnnotation).where(
 | 
			
		||||
            SqlSequenceAnnotation.sequence_name == seq
 | 
			
		||||
        seq_stmt = sa.select(self.sequence_annotations_type).where(
 | 
			
		||||
            self.sequence_annotations_type.sequence_name == seq
 | 
			
		||||
        )
 | 
			
		||||
        with Session(self._sql_engine) as session:
 | 
			
		||||
            entry = session.scalars(stmt).one()
 | 
			
		||||
@ -273,9 +276,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    # override
 | 
			
		||||
    def category_to_sequence_names(self) -> Dict[str, List[str]]:
 | 
			
		||||
        stmt = sa.select(
 | 
			
		||||
            SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
 | 
			
		||||
            self.sequence_annotations_type.category,
 | 
			
		||||
            self.sequence_annotations_type.sequence_name,
 | 
			
		||||
        ).where(  # we limit results to sequences that have frames after all filters
 | 
			
		||||
            SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
 | 
			
		||||
            self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
 | 
			
		||||
        )
 | 
			
		||||
        with self._sql_engine.connect() as connection:
 | 
			
		||||
            cat_to_seqs = pd.read_sql(stmt, connection)
 | 
			
		||||
@ -414,14 +418,14 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            return stmt.where(*where_conditions) if where_conditions else stmt
 | 
			
		||||
 | 
			
		||||
        if self.limit_sequences_per_category_to <= 0:
 | 
			
		||||
            stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
 | 
			
		||||
            stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
 | 
			
		||||
        else:
 | 
			
		||||
            subquery = sa.select(
 | 
			
		||||
                SqlSequenceAnnotation.sequence_name,
 | 
			
		||||
                self.sequence_annotations_type.sequence_name,
 | 
			
		||||
                sa.func.row_number()
 | 
			
		||||
                .over(
 | 
			
		||||
                    order_by=sa.text("ROWID"),  # NOTE: ROWID is SQLite-specific
 | 
			
		||||
                    partition_by=SqlSequenceAnnotation.category,
 | 
			
		||||
                    partition_by=self.sequence_annotations_type.category,
 | 
			
		||||
                )
 | 
			
		||||
                .label("row_number"),
 | 
			
		||||
            )
 | 
			
		||||
@ -457,21 +461,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Limiting dataset to categories: {self.pick_categories}")
 | 
			
		||||
        return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
 | 
			
		||||
        return [self.sequence_annotations_type.category.in_(self.pick_categories)]
 | 
			
		||||
 | 
			
		||||
    def _get_pick_filters(self) -> List[sa.ColumnElement]:
 | 
			
		||||
        if not self.pick_sequences:
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
 | 
			
		||||
        return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
 | 
			
		||||
        return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
 | 
			
		||||
 | 
			
		||||
    def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
 | 
			
		||||
        if not self.exclude_sequences:
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
 | 
			
		||||
        return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
 | 
			
		||||
        return [
 | 
			
		||||
            self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
 | 
			
		||||
        subsets = self.subsets
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user