about summary refs log tree commit diff homepage
path: root/tensorflow/core/ops/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--  tensorflow/core/ops/training_ops.cc  22
1 file changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index e6a9c0c018..1d24ea36a3 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -103,6 +103,28 @@ use_locking: If `True`, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
+REGISTER_OP("ApplyDelayCompensatedGradientDescent")
+ .Input("var: resource")
+ .Input("alpha: T")
+ .Input("delta: T")
+ .Input("lambda: T")
+ .Input("shadow: resource")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ApplyGradientDescentShapeFn)
+ .Doc(R"doc(
+var -= alpha * (delta + lambda * delta * (var - shadow))
+Update '*shadow' by changing it to the new value of 'var'
+
+var: Should be from a Variable().
+alpha: Scaling factor. Must be a scalar.
+delta: The change.
+lambda: The variance parameter.
+shadow: Same as "var".
+use_locking: If `True`, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
bool sparse) {
ShapeHandle unused;