aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/adagrad_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/adagrad_test.py')
-rw-r--r--tensorflow/compiler/tests/adagrad_test.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index a5c5885b42..9a93b32164 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -49,9 +49,11 @@ class AdagradOptimizerTest(XLATestCase):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ float_rtol=1e-5)
def testTensorLearningRate(self):
for dtype in self.float_types:
@@ -73,9 +75,11 @@ class AdagradOptimizerTest(XLATestCase):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ float_rtol=1e-5)
def testSharing(self):
for dtype in self.float_types:
@@ -107,9 +111,11 @@ class AdagradOptimizerTest(XLATestCase):
ada_update1.run()
# Validate updated params (the same as with only 1 Adagrad).
self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+ float_rtol=1e-5)
self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+ float_rtol=1e-5)
if __name__ == "__main__":