author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-13 13:09:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-07-13 13:13:54 -0700
commit    bb7541b96c49b06b5c13775f3666ae2b8450a457 (patch)
tree      c62b1c22878eb0166a370da3dd1d9ec338415040 /tensorflow/python/keras/callbacks.py
parent    63e6b9bf43049472b33393df74de271b6aa33863 (diff)
Automated rollback of commit 57527f7e47e3e67966b432065f510a601a4d8647
PiperOrigin-RevId: 204516578
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--  tensorflow/python/keras/callbacks.py | 28 ++++------------------------
1 file changed, 4 insertions(+), 24 deletions(-)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 5d66db232a..53d907a2cc 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -32,10 +32,8 @@ import numpy as np
 import six
 
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary as tf_summary
 from tensorflow.python.util.tf_export import tf_export
@@ -644,35 +642,17 @@ class LearningRateScheduler(Callback):
     self.verbose = verbose
 
   def on_epoch_begin(self, epoch, logs=None):
-    # TODO(yashkatariya): Change the property checking when the learning
-    # rate attribute is unified across all TF Optimizers.
-    if isinstance(self.model.optimizer, optimizers.TFOptimizer):
-      if not hasattr(self.model.optimizer.optimizer, '_lr') and not hasattr(
-          self.model.optimizer.optimizer, '_learning_rate'):
-        raise ValueError(
-            'TF Optimizer must have a "_lr" or "_learning_rate" attribute.')
-      else:
-        opt = self.model.optimizer.optimizer
-        if hasattr(opt, '_lr'):
-          opt_lr = Variable(opt._lr)  # pylint: disable=protected-access
-        elif hasattr(opt, '_learning_rate'):
-          opt_lr = Variable(opt._learning_rate)  # pylint: disable=protected-access
-    else:
-      if not hasattr(self.model.optimizer, 'lr'):
-        raise ValueError('Optimizer must have a "lr" attribute.')
-      else:
-        opt = self.model.optimizer
-        opt_lr = opt.lr
-
+    if not hasattr(self.model.optimizer, 'lr'):
+      raise ValueError('Optimizer must have a "lr" attribute.')
     try:  # new API
-      lr = float(K.get_value(opt_lr))
+      lr = float(K.get_value(self.model.optimizer.lr))
       lr = self.schedule(epoch, lr)
     except TypeError:  # Support for old API for backward compatibility
       lr = self.schedule(epoch)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function '
                        'should be float.')
-    K.set_value(opt_lr, lr)
+    K.set_value(self.model.optimizer.lr, lr)
     if self.verbose > 0:
       print('\nEpoch %05d: LearningRateScheduler reducing learning '
             'rate to %s.' % (epoch + 1, lr))
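
For context, a minimal usage sketch of the behavior this rollback restores: LearningRateScheduler reads and writes the learning rate only through the optimizer's lr attribute, which the built-in Keras optimizers expose; a raw TF optimizer wrapped in optimizers.TFOptimizer now fails the hasattr check and raises the ValueError shown above. The toy data, model, and schedule below are illustrative assumptions, not code from this commit.

import numpy as np
import tensorflow as tf

# Toy data (illustrative assumption).
x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')

model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(4,))])
# Built-in Keras optimizers expose an 'lr' attribute, so the callback's
# hasattr(self.model.optimizer, 'lr') check passes.
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss='mse')

def schedule(epoch, lr):
  # New-style schedule: receives the current rate and must return a float.
  # An old-style one-argument schedule(epoch) still works via the
  # TypeError fallback in on_epoch_begin.
  return lr * 0.5 if epoch >= 2 else lr

model.fit(
    x, y,
    epochs=4,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)])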