mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-12 15:35:58 +08:00
Compare commits
2 Commits
0c3b204375
...
cbcae096a0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbcae096a0 | ||
|
|
5b1cce56bc |
@@ -483,9 +483,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
*self._get_pick_filters(),
|
*self._get_pick_filters(),
|
||||||
*self._get_exclude_filters(),
|
*self._get_exclude_filters(),
|
||||||
]
|
]
|
||||||
if self.pick_sequences_sql_clause:
|
if pick_sequences_sql_clause := self.pick_sequences_sql_clause:
|
||||||
print("Applying the custom SQL clause.")
|
print("Applying the custom SQL clause.")
|
||||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
where_conditions.append(sa.text(pick_sequences_sql_clause))
|
||||||
|
|
||||||
def add_where(stmt):
|
def add_where(stmt):
|
||||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
@@ -505,6 +506,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
|
|
||||||
subquery = add_where(subquery).subquery()
|
subquery = add_where(subquery).subquery()
|
||||||
stmt = sa.select(subquery.c.sequence_name).where(
|
stmt = sa.select(subquery.c.sequence_name).where(
|
||||||
|
# pyre-ignore[6]: SQLAlchemy column comparison returns ColumnElement, not bool
|
||||||
subquery.c.row_number <= self.limit_sequences_per_category_to
|
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -633,9 +635,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pick_frames_sql_clause:
|
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||||
logger.info("Applying the custom SQL clause.")
|
logger.info("Applying the custom SQL clause.")
|
||||||
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
pick_frames_criteria.append(sa.text(pick_frames_sql_clause))
|
||||||
|
|
||||||
if pick_frames_criteria:
|
if pick_frames_criteria:
|
||||||
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
||||||
@@ -698,9 +701,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pick_frames_sql_clause:
|
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||||
logger.info(" applying custom SQL clause")
|
logger.info(" applying custom SQL clause")
|
||||||
where_conditions.append(sa.text(self.pick_frames_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
where_conditions.append(sa.text(pick_frames_sql_clause))
|
||||||
|
|
||||||
if where_conditions:
|
if where_conditions:
|
||||||
stmt = stmt.where(*where_conditions)
|
stmt = stmt.where(*where_conditions)
|
||||||
|
|||||||
@@ -679,8 +679,8 @@ class TestTranslate(TestCaseMixin, unittest.TestCase):
|
|||||||
im = t.inverse()._matrix
|
im = t.inverse()._matrix
|
||||||
im_2 = t._matrix.inverse()
|
im_2 = t._matrix.inverse()
|
||||||
im_comp = t.get_matrix().inverse()
|
im_comp = t.get_matrix().inverse()
|
||||||
self.assertClose(im, im_comp)
|
self.assertClose(im, im_comp, atol=1e-4)
|
||||||
self.assertClose(im, im_2)
|
self.assertClose(im, im_2, atol=1e-4)
|
||||||
|
|
||||||
def test_get_item(self, batch_size=5):
|
def test_get_item(self, batch_size=5):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
|
|||||||
Reference in New Issue
Block a user