mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Generalising SqlIndexDataset to support subtypes of SqlSequenceAnnotation
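Summary: Until now we rarely needed to extend sequence-level metadata, but applications such as text-to-3D/video need to store captions and similar per-sequence data. SqlIndexDataset therefore gains a `sequence_annotations_type` class variable (defaulting to SqlSequenceAnnotation), analogous to the existing `frame_annotations_type`, and all queries now go through it instead of the hard-coded SqlSequenceAnnotation class.
Reviewed By: bottler
Differential Revision: D68269926
fbshipit-source-id: f8af308adce51863d719a335d85cd2558943bd4c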
This commit is contained in:
parent
699bc671ca
commit
42a4a7d432
@@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||||
|
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
|
||||||
|
SqlSequenceAnnotation
|
||||||
|
)
|
||||||
|
|
||||||
sqlite_metadata_file: str = ""
|
sqlite_metadata_file: str = ""
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: Optional[str] = None
|
||||||
@@ -246,8 +249,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
self.frame_annotations_type.frame_number
|
self.frame_annotations_type.frame_number
|
||||||
== int(frame), # cast from np.int64
|
== int(frame), # cast from np.int64
|
||||||
)
|
)
|
||||||
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
||||||
SqlSequenceAnnotation.sequence_name == seq
|
self.sequence_annotations_type.sequence_name == seq
|
||||||
)
|
)
|
||||||
with Session(self._sql_engine) as session:
|
with Session(self._sql_engine) as session:
|
||||||
entry = session.scalars(stmt).one()
|
entry = session.scalars(stmt).one()
|
||||||
@@ -273,9 +276,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
# override
|
# override
|
||||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
self.sequence_annotations_type.category,
|
||||||
|
self.sequence_annotations_type.sequence_name,
|
||||||
).where( # we limit results to sequences that have frames after all filters
|
).where( # we limit results to sequences that have frames after all filters
|
||||||
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
|
||||||
)
|
)
|
||||||
with self._sql_engine.connect() as connection:
|
with self._sql_engine.connect() as connection:
|
||||||
cat_to_seqs = pd.read_sql(stmt, connection)
|
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||||
@@ -414,14 +418,14 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
|
|
||||||
if self.limit_sequences_per_category_to <= 0:
|
if self.limit_sequences_per_category_to <= 0:
|
||||||
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
|
||||||
else:
|
else:
|
||||||
subquery = sa.select(
|
subquery = sa.select(
|
||||||
SqlSequenceAnnotation.sequence_name,
|
self.sequence_annotations_type.sequence_name,
|
||||||
sa.func.row_number()
|
sa.func.row_number()
|
||||||
.over(
|
.over(
|
||||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||||
partition_by=SqlSequenceAnnotation.category,
|
partition_by=self.sequence_annotations_type.category,
|
||||||
)
|
)
|
||||||
.label("row_number"),
|
.label("row_number"),
|
||||||
)
|
)
|
||||||
@@ -457,21 +461,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||||
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
|
||||||
|
|
||||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||||
if not self.pick_sequences:
|
if not self.pick_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||||
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
|
||||||
|
|
||||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||||
if not self.exclude_sequences:
|
if not self.exclude_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||||
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
return [
|
||||||
|
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
|
||||||
|
]
|
||||||
|
|
||||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||||
subsets = self.subsets
|
subsets = self.subsets
|
||||||
|
Loading…
x
Reference in New Issue
Block a user