mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-16 19:30:34 +08:00
Fix for T251460511 ("Your diff, D90498281, broke one test")
Reviewed By: sgrigory Differential Revision: D90649493 fbshipit-source-id: 2a77c45ec8e6e5aa0a20437a765fbb9f0b566406
This commit is contained in:
committed by
meta-codesync[bot]
parent
0c3b204375
commit
5b1cce56bc
@@ -483,9 +483,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
*self._get_pick_filters(),
|
||||
*self._get_exclude_filters(),
|
||||
]
|
||||
if self.pick_sequences_sql_clause:
|
||||
if pick_sequences_sql_clause := self.pick_sequences_sql_clause:
|
||||
print("Applying the custom SQL clause.")
|
||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||
where_conditions.append(sa.text(pick_sequences_sql_clause))
|
||||
|
||||
def add_where(stmt):
|
||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||
@@ -505,6 +506,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
|
||||
subquery = add_where(subquery).subquery()
|
||||
stmt = sa.select(subquery.c.sequence_name).where(
|
||||
# pyre-ignore[6]: SQLAlchemy column comparison returns ColumnElement, not bool
|
||||
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||
)
|
||||
|
||||
@@ -633,9 +635,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
)
|
||||
)
|
||||
|
||||
if self.pick_frames_sql_clause:
|
||||
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||
logger.info("Applying the custom SQL clause.")
|
||||
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
|
||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||
pick_frames_criteria.append(sa.text(pick_frames_sql_clause))
|
||||
|
||||
if pick_frames_criteria:
|
||||
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
||||
@@ -698,9 +701,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
)
|
||||
)
|
||||
|
||||
if self.pick_frames_sql_clause:
|
||||
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||
logger.info(" applying custom SQL clause")
|
||||
where_conditions.append(sa.text(self.pick_frames_sql_clause))
|
||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||
where_conditions.append(sa.text(pick_frames_sql_clause))
|
||||
|
||||
if where_conditions:
|
||||
stmt = stmt.where(*where_conditions)
|
||||
|
||||
Reference in New Issue
Block a user