diff options
-rw-r--r-- | tensorflow/python/framework/test_util.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4bece9e25e..d63abd7f01 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase): def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, - "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + # When the array rank is small, print its contents. Numpy array printing is + # implemented using inefficient recursion so prints can cause tests to + # time out. + if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): + shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " + "%s.") % (a.shape, b.shape, b) + else: + shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, + b.shape) + self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + if not np.allclose(a, b, rtol=rtol, atol=atol): # Prints more details than np.testing.assert_allclose. # |