diff options
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--  tensorflow/core/ops/training_ops.cc  12
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 28af0bbfe2..b260a588e8 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -488,11 +488,13 @@ REGISTER_OP("ApplyMomentum")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
+    .Attr("use_nesterov: bool = false")
     .SetShapeFn([](InferenceContext* c) {
       return ApplyMomentumShapeFn(c, false /* sparse */);
     })
     .Doc(R"doc(
-Update '*var' according to the momentum scheme.
+Update '*var' according to the momentum scheme. Set use_nesterov = True if you
+want to use Nesterov momentum.
 
 accum = accum * momentum + grad
 var -= lr * accum
@@ -506,6 +508,9 @@
 out: Same as "var".
 use_locking: If `True`, updating of the var and accum tensors will be protected
 by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+var - lr * momentum * accum, so in the end, the var you get is actually
+var - lr * momentum * accum.
 )doc");
 
 REGISTER_OP("SparseApplyMomentum")
@@ -519,11 +524,13 @@ REGISTER_OP("SparseApplyMomentum")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .Attr("use_nesterov: bool = false")
     .SetShapeFn([](InferenceContext* c) {
       return ApplyMomentumShapeFn(c, true /* sparse */);
     })
     .Doc(R"doc(
 Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+Set use_nesterov = True if you want to use Nesterov momentum.
 
 That is for rows we have grad for, we update var and accum as follows:
 
@@ -540,6 +547,9 @@
 out: Same as "var".
 use_locking: If `True`, updating of the var and accum tensors will be protected
 by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+use_nesterov: If `True`, the tensor passed to compute grad will be
+var - lr * momentum * accum, so in the end, the var you get is actually
+var - lr * momentum * accum.
 )doc");
 
 static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {