diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-13 13:09:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-13 13:13:54 -0700 |
commit | bb7541b96c49b06b5c13775f3666ae2b8450a457 (patch) | |
tree | c62b1c22878eb0166a370da3dd1d9ec338415040 /tensorflow/python/keras/callbacks.py | |
parent | 63e6b9bf43049472b33393df74de271b6aa33863 (diff) |
Automated rollback of commit 57527f7e47e3e67966b432065f510a601a4d8647
PiperOrigin-RevId: 204516578
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 28 |
1 files changed, 4 insertions, 24 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 5d66db232a..53d907a2cc 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -32,10 +32,8 @@ import numpy as np import six from tensorflow.python.keras import backend as K -from tensorflow.python.keras import optimizers from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops -from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary from tensorflow.python.util.tf_export import tf_export @@ -644,35 +642,17 @@ class LearningRateScheduler(Callback): self.verbose = verbose def on_epoch_begin(self, epoch, logs=None): - # TODO(yashkatariya): Change the property checking when the learning - # rate attribute is unified across all TF Optimizers. - if isinstance(self.model.optimizer, optimizers.TFOptimizer): - if not hasattr(self.model.optimizer.optimizer, '_lr') and not hasattr( - self.model.optimizer.optimizer, '_learning_rate'): - raise ValueError( - 'TF Optimizer must have a "_lr" or "_learning_rate" attribute.') - else: - opt = self.model.optimizer.optimizer - if hasattr(opt, '_lr'): - opt_lr = Variable(opt._lr) # pylint: disable=protected-access - elif hasattr(opt, '_learning_rate'): - opt_lr = Variable(opt._learning_rate) # pylint: disable=protected-access - else: - if not hasattr(self.model.optimizer, 'lr'): - raise ValueError('Optimizer must have a "lr" attribute.') - else: - opt = self.model.optimizer - opt_lr = opt.lr - + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') try: # new API - lr = float(K.get_value(opt_lr)) + lr = float(K.get_value(self.model.optimizer.lr)) lr = self.schedule(epoch, lr) except TypeError: # Support for old API for backward compatibility lr = self.schedule(epoch) if not isinstance(lr, (float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') - K.set_value(opt_lr, lr) + K.set_value(self.model.optimizer.lr, lr) if self.verbose > 0: print('\nEpoch %05d: LearningRateScheduler reducing learning ' 'rate to %s.' % (epoch + 1, lr)) |