diff options
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r-- | tensorflow/core/ops/training_ops.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index 5eb011684b..eabec80c2e 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp") .Attr("use_locking: bool = false") .Doc(R"doc( Update '*var' according to the RMSProp algorithm. +Note that in this dense implementation of the algorithm, ms and mom will +update even if the grad is zero, but in the sparse implementation, ms +and mom will not update in iterations where the grad is zero. mean_square = decay * mean_square + (1-decay) * gradient ** 2 Delta = learning_rate * gradient / sqrt(mean_square + epsilon) @@ -461,5 +464,46 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); + +REGISTER_OP("SparseApplyRMSProp") + .Input("var: Ref(T)") + .Input("ms: Ref(T)") + .Input("mom: Ref(T)") + .Input("lr: T") + .Input("rho: T") + .Input("momentum: T") + .Input("epsilon: T") + .Input("grad: T") + .Input("indices: Tindices") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the RMSProp algorithm. +Note that in the dense implementation of this algorithm, ms and mom will +update even if the grad is zero, but in this sparse implementation, ms +and mom will not update in iterations where the grad is zero. + +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +var <- var - mom + +var: Should be from a Variable(). +ms: Should be from a Variable(). +mom: Should be from a Variable(). +lr: Scaling factor. Must be a scalar. +epsilon: Ridge term. Must be a scalar. +rho: Decay rate. Must be a scalar. +grad: The gradient. 
+indices: A vector of indices into the first dimension of var, ms and mom. +out: Same as "var". +use_locking: If `True`, updating of the var, ms, and mom tensors will be protected + by a lock; otherwise the behavior is undefined, but may exhibit less + contention. +)doc"); } // namespace tensorflow |