diff options
author | 2018-08-21 12:59:45 -0700 | |
---|---|---|
committer | 2018-08-21 13:03:34 -0700 | |
commit | fce0a4eaab4b2dfa59ffed3a3d11bb0c82d98263 (patch) | |
tree | f99d82d7a535b9e2d3271831e9c9452ba1df1525 /tensorflow/python/keras/optimizers_test.py | |
parent | d648d7e6e12774d5c60418a899d15b81a387c770 (diff) |
Minor fix to allow iterations variable to update in eager mode
PiperOrigin-RevId: 209644988
Diffstat (limited to 'tensorflow/python/keras/optimizers_test.py')
-rw-r--r-- | tensorflow/python/keras/optimizers_test.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py index 4d295351f5..22197938a5 100644 --- a/tensorflow/python/keras/optimizers_test.py +++ b/tensorflow/python/keras/optimizers_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.training.adam import AdamOptimizer @@ -153,6 +155,7 @@ class KerasOptimizersTest(test.TestCase): with self.assertRaises(NotImplementedError): optimizer.from_config(None) + @test_util.run_in_graph_and_eager_modes def test_tfoptimizer_iterations(self): with self.test_session(): optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01)) @@ -169,11 +172,15 @@ class KerasOptimizersTest(test.TestCase): verbose=0) self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11) - model.fit(np.random.random((20, 3)), - np.random.random((20, 2)), - steps_per_epoch=8, - verbose=0) - self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19) + if not context.executing_eagerly(): + # TODO(kathywu): investigate why training with an array input and + # setting the argument steps_per_epoch does not work in eager mode. + model.fit(np.random.random((20, 3)), + np.random.random((20, 2)), + steps_per_epoch=8, + verbose=0) + self.assertEqual( + keras.backend.get_value(model.optimizer.iterations), 19) def test_negative_clipvalue_or_clipnorm(self): with self.assertRaises(ValueError): |