diff options
Diffstat (limited to 'tensorflow/python/training/learning_rate_decay.py')
-rw-r--r-- | tensorflow/python/training/learning_rate_decay.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 51190264e8..fd195a7965 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -356,7 +356,15 @@ def natural_exp_decay(learning_rate, The function returns the decayed learning rate. It is computed as: ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * global_step) + decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / + decay_step)) ``` Example: decay exponentially with a base of 0.96: @@ -365,8 +373,10 @@ def natural_exp_decay(learning_rate, ... global_step = tf.Variable(0, trainable=False) learning_rate = 0.1 + decay_steps = 5 k = 0.5 - learning_rate = tf.train.exponential_time_decay(learning_rate, global_step, k) + learning_rate = tf.train.natural_exp_decay(learning_rate, global_step, + decay_steps, k) # Passing global_step to minimize() will increment it at each step. learning_step = ( |