aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-06-12 14:03:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 14:08:19 -0700
commitabfdf45dcdfe366376d859bf29166c0ad16d9993 (patch)
treef6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras/callbacks.py
parent9c7ba7503402bd02045f2464ef315db69699d6a9 (diff)
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 8061d47295..70b6a8431a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -635,7 +635,11 @@ class LearningRateScheduler(Callback):
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
- lr = self.schedule(epoch)
+ try: # new API
+ lr = float(K.get_value(self.model.optimizer.lr))
+ lr = self.schedule(epoch, lr)
+ except TypeError: # Support for old API for backward compatibility
+ lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')