aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util_test.py')
-rw-r--r--tensorflow/python/framework/test_util_test.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 990e74d7e1..cb021c1170 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -193,6 +193,55 @@ class TestUtilTest(test_util.TensorFlowTestCase):
y = [15]
control_flow_ops.Assert(x, y).run()
+ def testAssertAllCloseAccordingToType(self):
+ # test float64
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-8], dtype=np.float64),
+ np.asarray([2e-8], dtype=np.float64),
+ rtol=1e-8, atol=1e-8
+ )
+
+ with (self.assertRaises(AssertionError)):
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-7], dtype=np.float64),
+ np.asarray([2e-7], dtype=np.float64),
+ rtol=1e-8, atol=1e-8
+ )
+
+ # test float32
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-7], dtype=np.float32),
+ np.asarray([2e-7], dtype=np.float32),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7
+ )
+
+ with (self.assertRaises(AssertionError)):
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-6], dtype=np.float32),
+ np.asarray([2e-6], dtype=np.float32),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7
+ )
+
+ # test float16
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-4], dtype=np.float16),
+ np.asarray([2e-4], dtype=np.float16),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7,
+ half_rtol=1e-4, half_atol=1e-4
+ )
+
+ with (self.assertRaises(AssertionError)):
+ self.assertAllCloseAccordingToType(
+ np.asarray([1e-3], dtype=np.float16),
+ np.asarray([2e-3], dtype=np.float16),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7,
+ half_rtol=1e-4, half_atol=1e-4
+ )
+
def testRandomSeed(self):
a = random.randint(1, 1000)
a_np_rand = np.random.rand(1)