aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 14:57:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 14:57:10 -0700
commit13446601830e70834fa1bab2a061fab8a4150fa7 (patch)
tree470ee6dc4fcfc289644255995a05d26bbd51d462 /tensorflow/python/framework
parentaab3c53e1484404a70565324d1231c4e6ead7425 (diff)
parentd5d8a1bd06751b3ad166380a0a0ca00a3412145b (diff)
Merge pull request #22006 from facaiy:CLN/remove_print_for_assert
PiperOrigin-RevId: 214335741
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py30
-rw-r--r--tensorflow/python/framework/test_util_test.py8
2 files changed, 24 insertions, 14 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index e55a1b84ee..cd0b03be43 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1401,35 +1401,36 @@ class TensorFlowTestCase(googletest.TestCase):
b.shape)
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+ msgs = [msg]
if not np.allclose(a, b, rtol=rtol, atol=atol):
- # Prints more details than np.testing.assert_allclose.
+ # Adds more details to np.testing.assert_allclose.
#
# NOTE: numpy.allclose (and numpy.testing.assert_allclose)
# checks whether two arrays are element-wise equal within a
# tolerance. The relative difference (rtol * abs(b)) and the
# absolute difference atol are added together to compare against
# the absolute difference between a and b. Here, we want to
- # print out which elements violate such conditions.
+ # tell user which elements violate such conditions.
cond = np.logical_or(
np.abs(a - b) > atol + rtol * np.abs(b),
np.isnan(a) != np.isnan(b))
if a.ndim:
x = a[np.where(cond)]
y = b[np.where(cond)]
- print("not close where = ", np.where(cond))
+ msgs.append("not close where = {}".format(np.where(cond)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not close lhs = ", x)
- print("not close rhs = ", y)
- print("not close dif = ", np.abs(x - y))
- print("not close tol = ", atol + rtol * np.abs(y))
- print("dtype = %s, shape = %s" % (a.dtype, a.shape))
+ msgs.append("not close lhs = {}".format(x))
+ msgs.append("not close rhs = {}".format(y))
+ msgs.append("not close dif = {}".format(np.abs(x - y)))
+ msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
+ msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
# TODO(xpan): There seems to be a bug:
# tensorflow/compiler/tests:binary_ops_test pass with float32
# nan even though the equal_nan is False by default internally.
np.testing.assert_allclose(
- a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
+ a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
def _assertAllCloseRecursive(self,
a,
@@ -1611,19 +1612,20 @@ class TensorFlowTestCase(googletest.TestCase):
np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
]):
same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
+ msgs = [msg]
if not np.all(same):
- # Prints more details than np.testing.assert_array_equal.
+ # Adds more details to np.testing.assert_array_equal.
diff = np.logical_not(same)
if a.ndim:
x = a[np.where(diff)]
y = b[np.where(diff)]
- print("not equal where = ", np.where(diff))
+ msgs.append("not equal where = {}".format(np.where(diff)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not equal lhs = ", x)
- print("not equal rhs = ", y)
- np.testing.assert_array_equal(a, b, err_msg=msg)
+ msgs.append("not equal lhs = {}".format(x))
+ msgs.append("not equal rhs = {}".format(y))
+ np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
def assertAllGreater(self, a, comparison_target):
"""Assert element values are all greater than a target value.
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index c4f8fa9108..22189afa59 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -268,6 +268,11 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(7, 7 + 1e-5)
@test_util.run_in_graph_and_eager_modes
+ def testAllCloseList(self):
+ with self.assertRaisesRegexp(AssertionError, r"not close dif"):
+ self.assertAllClose([0], [1])
+
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
@@ -452,6 +457,9 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)
+ with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
+ self.assertAllEqual([0] * 3, k)
+
@test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays