aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 8c73977fc9..f2fd687adf 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -491,7 +491,9 @@ class TensorFlowTestCase(googletest.TestCase):
print("dtype = %s, shape = %s" % (a.dtype, a.shape))
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
- def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6):
+ def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6,
+ float_rtol=1e-6, float_atol=1e-6,
+ half_rtol=1e-3, half_atol=1e-3):
"""Like assertAllClose, but also suitable for comparing fp16 arrays.
In particular, the tolerance is reduced to 1e-3 if at least
@@ -502,12 +504,19 @@ class TensorFlowTestCase(googletest.TestCase):
b: a numpy ndarray or anything can be converted to one.
rtol: relative tolerance
atol: absolute tolerance
+ float_rtol: relative tolerance for float32
+ float_atol: absolute tolerance for float32
+ half_rtol: relative tolerance for float16
+ half_atol: absolute tolerance for float16
"""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
+ if a.dtype == np.float32 or b.dtype == np.float32:
+ rtol = max(rtol, float_rtol)
+ atol = max(atol, float_atol)
if a.dtype == np.float16 or b.dtype == np.float16:
- rtol = max(rtol, 1e-3)
- atol = max(atol, 1e-3)
+ rtol = max(rtol, half_rtol)
+ atol = max(atol, half_atol)
self.assertAllClose(a, b, rtol=rtol, atol=atol)