diff options
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r-- | tensorflow/core/ops/training_ops.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index e6a9c0c018..1d24ea36a3 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -103,6 +103,28 @@ use_locking: If `True`, the subtraction will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); +REGISTER_OP("ApplyDelayCompensatedGradientDescent") + .Input("var: resource") + .Input("alpha: T") + .Input("delta: T") + .Input("lambda: T") + .Input("shadow: resource") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .SetShapeFn(ApplyGradientDescentShapeFn) + .Doc(R"doc( +var -= alpha * (delta + lambda * delta * (var - shadow)) +Update '*shadow' by changing it to the new value of 'var' + +var: Should be from a Variable(). +alpha: Scaling factor. Must be a scalar. +delta: The change. +lambda: The variance parameter. +shadow: Same as "var". +use_locking: If `True`, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c, bool sparse) { ShapeHandle unused; |