diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-10 11:35:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 11:40:02 -0700 |
commit | a8b2dd9f72fe78cca59d525230f5358430fec45c (patch) | |
tree | a31badb6bd83f88a1a7a2b44b35669ae436c344c /tensorflow/python/framework | |
parent | c5b14b334e89b9bcb0fd0199481318b8fdd65762 (diff) |
Fix unhelpful error message
For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was.
PiperOrigin-RevId: 212303323
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4bece9e25e..d63abd7f01 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase): def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, - "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + # When the array rank is small, print its contents. Numpy array printing is + # implemented using inefficient recursion so prints can cause tests to + # time out. + if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): + shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " + "%s.") % (a.shape, b.shape, b) + else: + shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, + b.shape) + self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + if not np.allclose(a, b, rtol=rtol, atol=atol): # Prints more details than np.testing.assert_allclose. # |