diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 14:57:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 14:57:10 -0700 |
commit | 13446601830e70834fa1bab2a061fab8a4150fa7 (patch) | |
tree | 470ee6dc4fcfc289644255995a05d26bbd51d462 /tensorflow/python/framework | |
parent | aab3c53e1484404a70565324d1231c4e6ead7425 (diff) | |
parent | d5d8a1bd06751b3ad166380a0a0ca00a3412145b (diff) |
Merge pull request #22006 from facaiy:CLN/remove_print_for_assert
PiperOrigin-RevId: 214335741
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 30 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util_test.py | 8 |
2 files changed, 24 insertions, 14 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index e55a1b84ee..cd0b03be43 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1401,35 +1401,36 @@ class TensorFlowTestCase(googletest.TestCase): b.shape) self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + msgs = [msg] if not np.allclose(a, b, rtol=rtol, atol=atol): - # Prints more details than np.testing.assert_allclose. + # Adds more details to np.testing.assert_allclose. # # NOTE: numpy.allclose (and numpy.testing.assert_allclose) # checks whether two arrays are element-wise equal within a # tolerance. The relative difference (rtol * abs(b)) and the # absolute difference atol are added together to compare against # the absolute difference between a and b. Here, we want to - # print out which elements violate such conditions. + # tell user which elements violate such conditions. cond = np.logical_or( np.abs(a - b) > atol + rtol * np.abs(b), np.isnan(a) != np.isnan(b)) if a.ndim: x = a[np.where(cond)] y = b[np.where(cond)] - print("not close where = ", np.where(cond)) + msgs.append("not close where = {}".format(np.where(cond))) else: # np.where is broken for scalars x, y = a, b - print("not close lhs = ", x) - print("not close rhs = ", y) - print("not close dif = ", np.abs(x - y)) - print("not close tol = ", atol + rtol * np.abs(y)) - print("dtype = %s, shape = %s" % (a.dtype, a.shape)) + msgs.append("not close lhs = {}".format(x)) + msgs.append("not close rhs = {}".format(y)) + msgs.append("not close dif = {}".format(np.abs(x - y))) + msgs.append("not close tol = {}".format(atol + rtol * np.abs(y))) + msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape)) # TODO(xpan): There seems to be a bug: # tensorflow/compiler/tests:binary_ops_test pass with float32 # nan even though the equal_nan is False by default internally. np.testing.assert_allclose( - a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) + a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True) def _assertAllCloseRecursive(self, a, @@ -1611,19 +1612,20 @@ class TensorFlowTestCase(googletest.TestCase): np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype ]): same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) + msgs = [msg] if not np.all(same): - # Prints more details than np.testing.assert_array_equal. + # Adds more details to np.testing.assert_array_equal. diff = np.logical_not(same) if a.ndim: x = a[np.where(diff)] y = b[np.where(diff)] - print("not equal where = ", np.where(diff)) + msgs.append("not equal where = {}".format(np.where(diff))) else: # np.where is broken for scalars x, y = a, b - print("not equal lhs = ", x) - print("not equal rhs = ", y) - np.testing.assert_array_equal(a, b, err_msg=msg) + msgs.append("not equal lhs = {}".format(x)) + msgs.append("not equal rhs = {}".format(y)) + np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) def assertAllGreater(self, a, comparison_target): """Assert element values are all greater than a target value. diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index c4f8fa9108..22189afa59 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -268,6 +268,11 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertAllClose(7, 7 + 1e-5) @test_util.run_in_graph_and_eager_modes + def testAllCloseList(self): + with self.assertRaisesRegexp(AssertionError, r"not close dif"): + self.assertAllClose([0], [1]) + + @test_util.run_in_graph_and_eager_modes def testAllCloseDictToNonDict(self): with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"): self.assertAllClose(1, {"a": 1}) @@ -452,6 +457,9 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertAllEqual([120] * 3, k) self.assertAllEqual([20] * 3, j) + with self.assertRaisesRegexp(AssertionError, r"not equal lhs"): + self.assertAllEqual([0] * 3, k) + @test_util.run_in_graph_and_eager_modes def testAssertNotAllClose(self): # Test with arrays |