aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--tensorflow/core/ops/training_ops.cc44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 5eb011684b..eabec80c2e 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp")
.Attr("use_locking: bool = false")
.Doc(R"doc(
Update '*var' according to the RMSProp algorithm.
+Note that in dense implement of this algorithm, ms and mom will
+update even if the grad is zero, but in this sparse implement, ms
+and mom will not update in iterations the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
@@ -461,5 +464,46 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
)doc");
+
+REGISTER_OP("SparseApplyRMSProp")
+ .Input("var: Ref(T)")
+ .Input("ms: Ref(T)")
+ .Input("mom: Ref(T)")
+ .Input("lr: T")
+ .Input("rho: T")
+ .Input("momentum: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the RMSProp algorithm.
+Note that in dense implement of this algorithm, ms and mom will
+update even if the grad is zero, but in this sparse implement, ms
+and mom will not update in iterations the grad is zero.
+
+mean_square = decay * mean_square + (1-decay) * gradient ** 2
+Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+
+ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+var <- var - mom
+
+var: Should be from a Variable().
+ms: Should be from a Variable().
+mom: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+rho: Decay rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var, ms and mom.
+out: Same as "var".
+use_locking: If `True`, updating of the var, m, and v tensors will be protected
+ by a lock; otherwise the behavior is undefined, but may exhibit less
+ contention.
+)doc");
} // namespace tensorflow