aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 11:35:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 11:40:02 -0700
commita8b2dd9f72fe78cca59d525230f5358430fec45c (patch)
treea31badb6bd83f88a1a7a2b44b35669ae436c344c /tensorflow/python/framework
parentc5b14b334e89b9bcb0fd0199481318b8fdd65762 (diff)
Fix unhelpful error message
For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was. PiperOrigin-RevId: 212303323
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4bece9e25e..d63abd7f01 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(
- a.shape, b.shape,
- "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ # When the array rank is small, print its contents. Numpy array printing is
+ # implemented using inefficient recursion so prints can cause tests to
+ # time out.
+ if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
+ shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
+ "%s.") % (a.shape, b.shape, b)
+ else:
+ shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
+ b.shape)
+ self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#