Generalising SqlIndexDataset to support subtypes of SqlSequenceAnnotation

Summary: We did not often extend sequence-level metadata but now for applications like text-to-3D/video, we need to store captions and similar.

Reviewed By: bottler

Differential Revision: D68269926

fbshipit-source-id: f8af308adce51863d719a335d85cd2558943bd4c
This commit is contained in:
Roman Shapovalov 2025-01-20 03:39:06 -08:00 committed by Facebook GitHub Bot
parent 699bc671ca
commit 42a4a7d432

View File

@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
""" """
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
SqlSequenceAnnotation
)
sqlite_metadata_file: str = "" sqlite_metadata_file: str = ""
dataset_root: Optional[str] = None dataset_root: Optional[str] = None
@ -246,8 +249,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
self.frame_annotations_type.frame_number self.frame_annotations_type.frame_number
== int(frame), # cast from np.int64 == int(frame), # cast from np.int64
) )
seq_stmt = sa.select(SqlSequenceAnnotation).where( seq_stmt = sa.select(self.sequence_annotations_type).where(
SqlSequenceAnnotation.sequence_name == seq self.sequence_annotations_type.sequence_name == seq
) )
with Session(self._sql_engine) as session: with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one() entry = session.scalars(stmt).one()
@ -273,9 +276,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
# override # override
def category_to_sequence_names(self) -> Dict[str, List[str]]: def category_to_sequence_names(self) -> Dict[str, List[str]]:
stmt = sa.select( stmt = sa.select(
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name self.sequence_annotations_type.category,
self.sequence_annotations_type.sequence_name,
).where( # we limit results to sequences that have frames after all filters ).where( # we limit results to sequences that have frames after all filters
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names()) self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
) )
with self._sql_engine.connect() as connection: with self._sql_engine.connect() as connection:
cat_to_seqs = pd.read_sql(stmt, connection) cat_to_seqs = pd.read_sql(stmt, connection)
@ -414,14 +418,14 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
return stmt.where(*where_conditions) if where_conditions else stmt return stmt.where(*where_conditions) if where_conditions else stmt
if self.limit_sequences_per_category_to <= 0: if self.limit_sequences_per_category_to <= 0:
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name)) stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
else: else:
subquery = sa.select( subquery = sa.select(
SqlSequenceAnnotation.sequence_name, self.sequence_annotations_type.sequence_name,
sa.func.row_number() sa.func.row_number()
.over( .over(
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
partition_by=SqlSequenceAnnotation.category, partition_by=self.sequence_annotations_type.category,
) )
.label("row_number"), .label("row_number"),
) )
@ -457,21 +461,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
return [] return []
logger.info(f"Limiting dataset to categories: {self.pick_categories}") logger.info(f"Limiting dataset to categories: {self.pick_categories}")
return [SqlSequenceAnnotation.category.in_(self.pick_categories)] return [self.sequence_annotations_type.category.in_(self.pick_categories)]
def _get_pick_filters(self) -> List[sa.ColumnElement]: def _get_pick_filters(self) -> List[sa.ColumnElement]:
if not self.pick_sequences: if not self.pick_sequences:
return [] return []
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}") logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)] return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
def _get_exclude_filters(self) -> List[sa.ColumnOperators]: def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
if not self.exclude_sequences: if not self.exclude_sequences:
return [] return []
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}") logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)] return [
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
]
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame: def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
subsets = self.subsets subsets = self.subsets