aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/state_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r--tensorflow/core/ops/state_ops.cc80
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index cc0c652107..6de3f97548 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -348,6 +348,86 @@ use_locking: If True, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
+REGISTER_OP("ScatterMul")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape)
+ .Doc(R"doc(
+Multiplies sparse updates into a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] *= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] *= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to multiply to `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the operation will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterDiv")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape)
+ .Doc(R"doc(
+Divides a variable reference by sparse updates.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] /= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] /= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions divide.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of values that `ref` is divided by.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the operation will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
REGISTER_OP("CountUpTo")
.Input("ref: Ref(T)")
.Output("output: T")