Always print message on test failure

Summary: make assertClose print its failure information even if a message is supplied.

Reviewed By: nikhilaravi

Differential Revision: D28799745

fbshipit-source-id: 787c8c356342420cd8f40fdc0b2aba036142298e
This commit is contained in:
Jeremy Reizenstein 2021-06-03 18:29:42 -07:00 committed by Facebook GitHub Bot
parent 7fd7de4451
commit 070ec550d3
2 changed files with 22 additions and 19 deletions

View File

@ -154,22 +154,24 @@ class TestCaseMixin(unittest.TestCase):
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
)
if not close and msg is None:
diff = backend.abs(input + 0.0 - other)
ratio = diff / backend.abs(other)
try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
if try_relative.all():
if backend == np:
# Avoid a weirdness with zero dimensional arrays.
ratio = np.array(ratio)
ratio[diff <= atol] = 0
extra = f" Max relative diff {ratio.max()}"
else:
extra = ""
shape = tuple(input.shape)
loc = np.unravel_index(diff.argmax(), shape)
max_diff = diff.max()
msg = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}."
self.fail(msg)
if close:
return
self.assertTrue(close, msg)
diff = backend.abs(input + 0.0 - other)
ratio = diff / backend.abs(other)
try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
if try_relative.all():
if backend == np:
# Avoid a weirdness with zero dimensional arrays.
ratio = np.array(ratio)
ratio[diff <= atol] = 0
extra = f" Max relative diff {ratio.max()}"
else:
extra = ""
shape = tuple(input.shape)
loc = np.unravel_index(int(diff.argmax()), shape)
max_diff = diff.max()
err = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}."
if msg is not None:
self.fail(f"{msg} {err}")
self.fail(err)

View File

@ -36,7 +36,8 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase):
self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
self.assertIn(msg, str(context.exception))
self.assertIn("Not close", str(context.exception))
# test relative tolerance
assert torch.allclose(x, x_noise, rtol=100 * noise_std)