From 070ec550d387185c691fe02a7a3b8459ab3b6d9f Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 3 Jun 2021 18:29:42 -0700 Subject: [PATCH] Always print message on test failure Summary: make assertClose print its failure information even if a message is supplied. Reviewed By: nikhilaravi Differential Revision: D28799745 fbshipit-source-id: 787c8c356342420cd8f40fdc0b2aba036142298e --- tests/common_testing.py | 38 +++++++++++++++++++----------------- tests/test_common_testing.py | 3 ++- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/common_testing.py b/tests/common_testing.py index 6e2cdc46..0f32e7e2 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -154,22 +154,24 @@ class TestCaseMixin(unittest.TestCase): input, other, rtol=rtol, atol=atol, equal_nan=equal_nan ) - if not close and msg is None: - diff = backend.abs(input + 0.0 - other) - ratio = diff / backend.abs(other) - try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0)) - if try_relative.all(): - if backend == np: - # Avoid a weirdness with zero dimensional arrays. - ratio = np.array(ratio) - ratio[diff <= atol] = 0 - extra = f" Max relative diff {ratio.max()}" - else: - extra = "" - shape = tuple(input.shape) - loc = np.unravel_index(diff.argmax(), shape) - max_diff = diff.max() - msg = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}." - self.fail(msg) + if close: + return - self.assertTrue(close, msg) + diff = backend.abs(input + 0.0 - other) + ratio = diff / backend.abs(other) + try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0)) + if try_relative.all(): + if backend == np: + # Avoid a weirdness with zero dimensional arrays. + ratio = np.array(ratio) + ratio[diff <= atol] = 0 + extra = f" Max relative diff {ratio.max()}" + else: + extra = "" + shape = tuple(input.shape) + loc = np.unravel_index(int(diff.argmax()), shape) + max_diff = diff.max() + err = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}." + if msg is not None: + self.fail(f"{msg} {err}") + self.fail(err) diff --git a/tests/test_common_testing.py b/tests/test_common_testing.py index c8976ad5..16e36fad 100644 --- a/tests/test_common_testing.py +++ b/tests/test_common_testing.py @@ -36,7 +36,8 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase): self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std) with self.assertRaises(AssertionError) as context: self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg) - self.assertTrue(msg in str(context.exception)) + self.assertIn(msg, str(context.exception)) + self.assertIn("Not close", str(context.exception)) # test relative tolerance assert torch.allclose(x, x_noise, rtol=100 * noise_std)