diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 209a8a26..f2b2b2ab 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -679,8 +679,8 @@ class TestTranslate(TestCaseMixin, unittest.TestCase): im = t.inverse()._matrix im_2 = t._matrix.inverse() im_comp = t.get_matrix().inverse() - self.assertClose(im, im_comp) - self.assertClose(im, im_2) + self.assertClose(im, im_comp, atol=1e-4) + self.assertClose(im, im_2, atol=1e-4) def test_get_item(self, batch_size=5): device = torch.device("cuda:0")