Diffstat (limited to 'tensorflow/python/training/learning_rate_decay.py')
 tensorflow/python/training/learning_rate_decay.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 51190264e8..fd195a7965 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -356,7 +356,15 @@ def natural_exp_decay(learning_rate,
The function returns the decayed learning rate. It is computed as:
```python
- decayed_learning_rate = learning_rate * exp(-decay_rate * global_step)
+  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
+                                               decay_steps)
+ ```
+
+ or, if `staircase` is `True`, as:
+
+ ```python
+  decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
+                                                                   decay_steps))
```
Example: decay exponentially with a decay rate of 0.5:
@@ -365,8 +373,10 @@ def natural_exp_decay(learning_rate,
...
global_step = tf.Variable(0, trainable=False)
learning_rate = 0.1
+ decay_steps = 5
k = 0.5
- learning_rate = tf.train.exponential_time_decay(learning_rate, global_step, k)
+  learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
+                                             decay_steps, k)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
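
As a sanity check on the two formulas in the updated docstring, here is a minimal plain-Python sketch (assuming Python 3 division and only the standard `math` module). The helper below is illustrative only; it is not TensorFlow's actual implementation of `tf.train.natural_exp_decay`:

```python
import math

def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=False):
    # Re-implements the docstring formulas: continuous decay by default,
    # piecewise-constant decay when staircase=True.
    p = global_step / decay_steps
    if staircase:
        # floor() holds the rate constant within each decay_steps interval.
        p = math.floor(p)
    return learning_rate * math.exp(-decay_rate * p)

# Values from the docstring example: learning_rate=0.1, decay_steps=5, k=0.5.
print(natural_exp_decay(0.1, 10, 5, 0.5))                 # 0.1 * exp(-1.0) ~= 0.0368
print(natural_exp_decay(0.1, 7, 5, 0.5, staircase=True))  # 0.1 * exp(-0.5) ~= 0.0607
```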