diff --git a/tests/common_testing.py b/tests/common_testing.py index 0f32e7e2..b1026257 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -120,7 +120,7 @@ class TestCaseMixin(unittest.TestCase): # all(norm_fn(input - other) <= atol + rtol * norm_fn(other)). self.assertClose( - diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan + diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg ) def assertClose(