Diffstat (limited to 'tensorflow/contrib/optimizer_v2/rmsprop.py')
-rw-r--r--  tensorflow/contrib/optimizer_v2/rmsprop.py | 32
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 164ff0ea06..3de53405ec 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -22,7 +22,7 @@ A detailed description of rmsprop.
- divide gradient by the root of this average
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
-mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
delta = - mom
This implementation of RMSProp uses plain momentum, not Nesterov momentum.
@@ -33,7 +33,7 @@ gradients, and uses that average to estimate the variance:
mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t /
- sqrt(mean_square - mean_grad**2 + epsilon)
+ sqrt(mean_square - mean_grad**2)
delta = - mom
"""
@@ -43,7 +43,6 @@ from __future__ import print_function
from tensorflow.contrib.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.training import training_ops
@@ -87,7 +86,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
decay: A float hyperparameter. Discounting factor for the history/coming
gradient.
momentum: A float hyperparameter.
- epsilon: A float hyperparameter. Small value to avoid zero denominator.
+ epsilon: A float hyperparameter. Small value to initialize the average
+ square gradient variable and avoid zero denominator.
use_locking: If True use locks for update operation.
centered: If True, gradients are normalized by the estimated variance of
the gradient; if False, by the uncentered second moment. Setting this to
@@ -106,10 +106,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
for v in var_list:
- if v.get_shape().is_fully_defined():
- init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
- else:
- init_rms = array_ops.ones_like(v)
+ init_rms = state.get_hyper(
+ "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
state.create_slot_with_initializer(v, init_rms, v.get_shape(),
v.dtype.base_dtype, "rms")
if self._centered:
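For context, a hedged usage sketch of how this optimizer is constructed: with this change, the epsilon argument sets the initial value of each per-variable "rms" slot (epsilon * ones_like(v) above) instead of being added to the denominator of the update. Argument values below are illustrative only.

from tensorflow.contrib.optimizer_v2 import rmsprop

opt = rmsprop.RMSPropOptimizer(
    learning_rate=0.001,
    decay=0.9,
    momentum=0.0,
    epsilon=1e-10,   # becomes the initial value of the "rms" slot
    centered=False)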
@@ -129,7 +127,9 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ # epsilon is now the rms initial value and is not added to the
+ # denominator anymore, hence calling the kernel op with epsilon=0.
+ 0,
grad,
use_locking=self._use_locking).op
else:
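A small numeric sketch of why 0 can be passed for the kernel's epsilon argument: the kernel computes sqrt(rms + epsilon) in the denominator, so with epsilon=0 the stabilizing term is carried inside rms via its initial value instead. The two conventions are not numerically identical (the previous slot initializer was 1.0); this only shows where epsilon enters each formula. Plain NumPy, values illustrative.

import numpy as np

decay, epsilon = 0.9, 1e-10
grad = np.array([0.1, -0.2])

# New convention: rms starts at epsilon, kernel op called with epsilon=0.
rms = epsilon * np.ones_like(grad)
rms = decay * rms + (1 - decay) * grad ** 2
denom_new = np.sqrt(rms + 0.0)

# Previous convention: rms started at 1.0, kernel op received the real epsilon.
rms_old = np.ones_like(grad)
rms_old = decay * rms_old + (1 - decay) * grad ** 2
denom_old = np.sqrt(rms_old + epsilon)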
@@ -140,7 +140,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking).op
@@ -157,7 +157,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking)
else:
@@ -168,7 +168,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking)
@@ -185,7 +185,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad.values,
grad.indices,
use_locking=self._use_locking)
@@ -197,7 +197,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad.values,
grad.indices,
use_locking=self._use_locking)
@@ -215,7 +215,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
indices,
use_locking=self._use_locking)
@@ -227,7 +227,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
indices,
use_locking=self._use_locking)