diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 8c73977fc9..f2fd687adf 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -491,7 +491,9 @@ class TensorFlowTestCase(googletest.TestCase): print("dtype = %s, shape = %s" % (a.dtype, a.shape)) np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6): + def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6, + float_rtol=1e-6, float_atol=1e-6, + half_rtol=1e-3, half_atol=1e-3): """Like assertAllClose, but also suitable for comparing fp16 arrays. In particular, the tolerance is reduced to 1e-3 if at least @@ -502,12 +504,19 @@ class TensorFlowTestCase(googletest.TestCase): b: a numpy ndarray or anything can be converted to one. rtol: relative tolerance atol: absolute tolerance + float_rtol: relative tolerance for float32 + float_atol: absolute tolerance for float32 + half_rtol: relative tolerance for float16 + half_atol: absolute tolerance for float16 """ a = self._GetNdArray(a) b = self._GetNdArray(b) + if a.dtype == np.float32 or b.dtype == np.float32: + rtol = max(rtol, float_rtol) + atol = max(atol, float_atol) if a.dtype == np.float16 or b.dtype == np.float16: - rtol = max(rtol, 1e-3) - atol = max(atol, 1e-3) + rtol = max(rtol, half_rtol) + atol = max(atol, half_atol) self.assertAllClose(a, b, rtol=rtol, atol=atol) |