diff options
Diffstat (limited to 'tensorflow/python/framework/test_util_test.py')
-rw-r--r-- | tensorflow/python/framework/test_util_test.py | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 990e74d7e1..cb021c1170 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -193,6 +193,55 @@ class TestUtilTest(test_util.TensorFlowTestCase): y = [15] control_flow_ops.Assert(x, y).run() + def testAssertAllCloseAccordingToType(self): + # test float64 + self.assertAllCloseAccordingToType( + np.asarray([1e-8], dtype=np.float64), + np.asarray([2e-8], dtype=np.float64), + rtol=1e-8, atol=1e-8 + ) + + with (self.assertRaises(AssertionError)): + self.assertAllCloseAccordingToType( + np.asarray([1e-7], dtype=np.float64), + np.asarray([2e-7], dtype=np.float64), + rtol=1e-8, atol=1e-8 + ) + + # test float32 + self.assertAllCloseAccordingToType( + np.asarray([1e-7], dtype=np.float32), + np.asarray([2e-7], dtype=np.float32), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7 + ) + + with (self.assertRaises(AssertionError)): + self.assertAllCloseAccordingToType( + np.asarray([1e-6], dtype=np.float32), + np.asarray([2e-6], dtype=np.float32), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7 + ) + + # test float16 + self.assertAllCloseAccordingToType( + np.asarray([1e-4], dtype=np.float16), + np.asarray([2e-4], dtype=np.float16), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7, + half_rtol=1e-4, half_atol=1e-4 + ) + + with (self.assertRaises(AssertionError)): + self.assertAllCloseAccordingToType( + np.asarray([1e-3], dtype=np.float16), + np.asarray([2e-3], dtype=np.float16), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7, + half_rtol=1e-4, half_atol=1e-4 + ) + def testRandomSeed(self): a = random.randint(1, 1000) a_np_rand = np.random.rand(1) |