Diffstat (limited to 'tensorflow/contrib/optimizer_v2/rmsprop.py')
-rw-r--r--  tensorflow/contrib/optimizer_v2/rmsprop.py | 32
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 164ff0ea06..3de53405ec 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -22,7 +22,7 @@ A detailed description of rmsprop.
 - divide gradient by the root of this average
 
 mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
-mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
 delta = - mom
 
 This implementation of RMSProp uses plain momentum, not Nesterov momentum.
@@ -33,7 +33,7 @@ gradients, and uses that average to estimate the variance:
 mean_grad = decay * mean_square{t-1} + (1-decay) * gradient
 mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
 mom = momentum * mom{t-1} + learning_rate * g_t /
-    sqrt(mean_square - mean_grad**2 + epsilon)
+    sqrt(mean_square - mean_grad**2)
 delta = - mom
 """
 
@@ -43,7 +43,6 @@ from __future__ import print_function
 
 from tensorflow.contrib.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
 from tensorflow.python.training import training_ops
 
 
@@ -87,7 +86,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
       decay: A float hyperparameter. Discounting factor for the history/coming
         gradient.
       momentum: A float hyperparameter.
-      epsilon: A float hyperparameter. Small value to avoid zero denominator.
+      epsilon: A float hyperparameter. Small value to initialize the average
+        square gradient variable and avoid zero denominator.
       use_locking: If True use locks for update operation.
       centered: If True, gradients are normalized by the estimated variance of
         the gradient; if False, by the uncentered second moment. Setting this to
@@ -106,10 +106,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
 
   def _create_vars(self, var_list, state):
     for v in var_list:
-      if v.get_shape().is_fully_defined():
-        init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
-      else:
-        init_rms = array_ops.ones_like(v)
+      init_rms = state.get_hyper(
+          "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
       state.create_slot_with_initializer(v, init_rms, v.get_shape(),
                                          v.dtype.base_dtype, "rms")
       if self._centered:
@@ -129,7 +127,9 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          # epsilon is now the rms initial value and is not added to the
+          # denominator anymore, hence calling the kernel op with epsilon=0.
+          0,
           grad,
           use_locking=self._use_locking).op
     else:
@@ -140,7 +140,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad,
           use_locking=self._use_locking).op
 
@@ -157,7 +157,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad,
           use_locking=self._use_locking)
     else:
@@ -168,7 +168,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad,
           use_locking=self._use_locking)
 
@@ -185,7 +185,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad.values,
           grad.indices,
           use_locking=self._use_locking)
@@ -197,7 +197,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad.values,
           grad.indices,
           use_locking=self._use_locking)
@@ -215,7 +215,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad,
           indices,
           use_locking=self._use_locking)
@@ -227,7 +227,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
           state.get_hyper("learning_rate", var.dtype.base_dtype),
           state.get_hyper("decay", var.dtype.base_dtype),
           state.get_hyper("momentum", var.dtype.base_dtype),
-          state.get_hyper("epsilon", var.dtype.base_dtype),
+          0,
           grad,
           indices,
           use_locking=self._use_locking)
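
For reference, the non-centered update rule after this change, as given in the module docstring, can be sketched in plain NumPy. This is a minimal illustration, not part of the patch: the function name, the toy values, and the NumPy framing are assumptions; the point is only that epsilon now seeds the rms accumulator while the kernel op is called with epsilon=0.

import numpy as np

def rmsprop_step(var, grad, rms, mom, lr=0.001, decay=0.9, momentum=0.0):
  # mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
  rms = decay * rms + (1.0 - decay) * grad ** 2
  # mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
  # No "+ epsilon" under the sqrt anymore.
  mom = momentum * mom + lr * grad / np.sqrt(rms)
  # delta = - mom
  return var - mom, rms, mom

# epsilon initializes the accumulator instead of padding the denominator
# (hypothetical toy values for illustration).
epsilon = 1e-10
var = np.array([1.0, 2.0])
rms = epsilon * np.ones_like(var)
mom = np.zeros_like(var)
var, rms, mom = rmsprop_step(var, np.array([0.5, -0.5]), rms, mom)

Note that this is not numerically identical to the old behavior, where rms started at ones and epsilon was added under the square root; the patch deliberately folds epsilon into the initial slot value instead.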