diff --git a/tests/common_testing.py b/tests/common_testing.py
index 6e2cdc46..0f32e7e2 100644
--- a/tests/common_testing.py
+++ b/tests/common_testing.py
@@ -154,22 +154,24 @@ class TestCaseMixin(unittest.TestCase):
             input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
         )
-        if not close and msg is None:
-            diff = backend.abs(input + 0.0 - other)
-            ratio = diff / backend.abs(other)
-            try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
-            if try_relative.all():
-                if backend == np:
-                    # Avoid a weirdness with zero dimensional arrays.
-                    ratio = np.array(ratio)
-                ratio[diff <= atol] = 0
-                extra = f" Max relative diff {ratio.max()}"
-            else:
-                extra = ""
-            shape = tuple(input.shape)
-            loc = np.unravel_index(diff.argmax(), shape)
-            max_diff = diff.max()
-            msg = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}."
-            self.fail(msg)
+        if close:
+            return
 
-        self.assertTrue(close, msg)
+        diff = backend.abs(input + 0.0 - other)
+        ratio = diff / backend.abs(other)
+        try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
+        if try_relative.all():
+            if backend == np:
+                # Avoid a weirdness with zero dimensional arrays.
+                ratio = np.array(ratio)
+            ratio[diff <= atol] = 0
+            extra = f" Max relative diff {ratio.max()}"
+        else:
+            extra = ""
+        shape = tuple(input.shape)
+        loc = np.unravel_index(int(diff.argmax()), shape)
+        max_diff = diff.max()
+        err = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}."
+        if msg is not None:
+            self.fail(f"{msg} {err}")
+        self.fail(err)
 
diff --git a/tests/test_common_testing.py b/tests/test_common_testing.py
index c8976ad5..16e36fad 100644
--- a/tests/test_common_testing.py
+++ b/tests/test_common_testing.py
@@ -36,7 +36,8 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase):
         self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
         with self.assertRaises(AssertionError) as context:
             self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
-        self.assertTrue(msg in str(context.exception))
+        self.assertIn(msg, str(context.exception))
+        self.assertIn("Not close", str(context.exception))
 
         # test relative tolerance
         assert torch.allclose(x, x_noise, rtol=100 * noise_std)