aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--tensorflow/core/ops/training_ops.cc12
1 file changed, 11 insertions, 1 deletion
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 28af0bbfe2..b260a588e8 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -488,11 +488,13 @@ REGISTER_OP("ApplyMomentum")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
+ .Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, false /* sparse */);
})
.Doc(R"doc(
-Update '*var' according to the momentum scheme.
+Update '*var' according to the momentum scheme. Set use_nesterov = True if you
+want to use Nesterov momentum.
accum = accum * momentum + grad
var -= lr * accum
@@ -506,6 +508,9 @@ out: Same as "var".
use_locking: If `True`, updating of the var and accum tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+var - lr * momentum * accum, so in the end, the var you get is actually
+var - lr * momentum * accum (i.e., the gradient is evaluated at this
+look-ahead value, and the variable ends at the same expression with the
+updated accum).
)doc");
REGISTER_OP("SparseApplyMomentum")
@@ -519,11 +524,13 @@ REGISTER_OP("SparseApplyMomentum")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, true /* sparse */);
})
.Doc(R"doc(
Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+Set use_nesterov = True if you want to use Nesterov momentum.
That is for rows we have grad for, we update var and accum as follows:
@@ -540,6 +547,9 @@ out: Same as "var".
use_locking: If `True`, updating of the var and accum tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+var - lr * momentum * accum, so in the end, the var you get is actually
+var - lr * momentum * accum (i.e., the gradient is evaluated at this
+look-ahead value, and the variable ends at the same expression with the
+updated accum).
)doc");
static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {