diff options
author | Francois Chollet <fchollet@google.com> | 2018-06-12 14:03:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-12 14:08:19 -0700 |
commit | abfdf45dcdfe366376d859bf29166c0ad16d9993 (patch) | |
tree | f6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras/callbacks.py | |
parent | 9c7ba7503402bd02045f2464ef315db69699d6a9 (diff) |
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 8061d47295..70b6a8431a 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -635,7 +635,11 @@ class LearningRateScheduler(Callback): def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, 'lr'): raise ValueError('Optimizer must have a "lr" attribute.') - lr = self.schedule(epoch) + try: # new API + lr = float(K.get_value(self.model.optimizer.lr)) + lr = self.schedule(epoch, lr) + except TypeError: # Support for old API for backward compatibility + lr = self.schedule(epoch) if not isinstance(lr, (float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') |