diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py index 70418e97..ea42de43 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): """ frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation + sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = ( + SqlSequenceAnnotation + ) sqlite_metadata_file: str = "" dataset_root: Optional[str] = None @@ -246,8 +249,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): self.frame_annotations_type.frame_number == int(frame), # cast from np.int64 ) - seq_stmt = sa.select(SqlSequenceAnnotation).where( - SqlSequenceAnnotation.sequence_name == seq + seq_stmt = sa.select(self.sequence_annotations_type).where( + self.sequence_annotations_type.sequence_name == seq ) with Session(self._sql_engine) as session: entry = session.scalars(stmt).one() @@ -273,9 +276,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # override def category_to_sequence_names(self) -> Dict[str, List[str]]: stmt = sa.select( - SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name + self.sequence_annotations_type.category, + self.sequence_annotations_type.sequence_name, ).where( # we limit results to sequences that have frames after all filters - SqlSequenceAnnotation.sequence_name.in_(self.sequence_names()) + self.sequence_annotations_type.sequence_name.in_(self.sequence_names()) ) with self._sql_engine.connect() as connection: cat_to_seqs = pd.read_sql(stmt, connection) @@ -414,14 +418,14 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): return stmt.where(*where_conditions) if where_conditions else stmt if self.limit_sequences_per_category_to <= 0: - stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name)) + stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name)) else: subquery = sa.select( - SqlSequenceAnnotation.sequence_name, + self.sequence_annotations_type.sequence_name, sa.func.row_number() .over( order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific - partition_by=SqlSequenceAnnotation.category, + partition_by=self.sequence_annotations_type.category, ) .label("row_number"), ) @@ -457,21 +461,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): return [] logger.info(f"Limiting dataset to categories: {self.pick_categories}") - return [SqlSequenceAnnotation.category.in_(self.pick_categories)] + return [self.sequence_annotations_type.category.in_(self.pick_categories)] def _get_pick_filters(self) -> List[sa.ColumnElement]: if not self.pick_sequences: return [] logger.info(f"Limiting dataset to sequences: {self.pick_sequences}") - return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)] + return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)] def _get_exclude_filters(self) -> List[sa.ColumnOperators]: if not self.exclude_sequences: return [] logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}") - return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)] + return [ + self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences) + ] def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame: subsets = self.subsets