2 Commits

Author SHA1 Message Date
Jeremy Reizenstein
cbcae096a0 Add atol=1e-4 to assertClose calls in test_inverse for Translate
Summary:
Added `atol=1e-4` tolerance parameter to the `assertClose` calls on lines 682 and 683 in the `test_inverse` method of `TestTranslate` class.

This is a retry of D90225548

Reviewed By: sgrigory

Differential Revision: D90682979

fbshipit-source-id: ac13f000174dd9962326296e1c3116d0d39c7751
2026-01-14 08:57:43 -08:00
generatedunixname537391475639613
5b1cce56bc Fix for T251460511 ("Your diff, D90498281, broke one test")
Reviewed By: sgrigory

Differential Revision: D90649493

fbshipit-source-id: 2a77c45ec8e6e5aa0a20437a765fbb9f0b566406
2026-01-14 08:53:26 -08:00
2 changed files with 12 additions and 8 deletions

View File

@@ -483,9 +483,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if self.pick_sequences_sql_clause:
if pick_sequences_sql_clause := self.pick_sequences_sql_clause:
print("Applying the custom SQL clause.")
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
# pyre-ignore[6]: TextClause is compatible with where conditions
where_conditions.append(sa.text(pick_sequences_sql_clause))
def add_where(stmt):
return stmt.where(*where_conditions) if where_conditions else stmt
@@ -505,6 +506,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
subquery = add_where(subquery).subquery()
stmt = sa.select(subquery.c.sequence_name).where(
# pyre-ignore[6]: SQLAlchemy column comparison returns ColumnElement, not bool
subquery.c.row_number <= self.limit_sequences_per_category_to
)
@@ -633,9 +635,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
)
)
if self.pick_frames_sql_clause:
if pick_frames_sql_clause := self.pick_frames_sql_clause:
logger.info("Applying the custom SQL clause.")
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
# pyre-ignore[6]: TextClause is compatible with where conditions
pick_frames_criteria.append(sa.text(pick_frames_sql_clause))
if pick_frames_criteria:
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
@@ -698,9 +701,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
)
)
if self.pick_frames_sql_clause:
if pick_frames_sql_clause := self.pick_frames_sql_clause:
logger.info(" applying custom SQL clause")
where_conditions.append(sa.text(self.pick_frames_sql_clause))
# pyre-ignore[6]: TextClause is compatible with where conditions
where_conditions.append(sa.text(pick_frames_sql_clause))
if where_conditions:
stmt = stmt.where(*where_conditions)

View File

@@ -679,8 +679,8 @@ class TestTranslate(TestCaseMixin, unittest.TestCase):
im = t.inverse()._matrix
im_2 = t._matrix.inverse()
im_comp = t.get_matrix().inverse()
self.assertClose(im, im_comp)
self.assertClose(im, im_2)
self.assertClose(im, im_comp, atol=1e-4)
self.assertClose(im, im_2, atol=1e-4)
def test_get_item(self, batch_size=5):
device = torch.device("cuda:0")