diff options
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index cc0c652107..6de3f97548 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -348,6 +348,86 @@ use_locking: If True, the subtraction will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); +REGISTER_OP("ScatterMul") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape) + .Doc(R"doc( +Multiplies sparse updates into a variable reference. + +This operation computes + + # Scalar indices + ref[indices, ...] *= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] *= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to multiply to `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the operation will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterDiv") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape) + .Doc(R"doc( +Divides a variable reference by sparse updates. + +This operation computes + + # Scalar indices + ref[indices, ...] /= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] /= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions divide. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of values that `ref` is divided by. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the operation will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + REGISTER_OP("CountUpTo") .Input("ref: Ref(T)") .Output("output: T") |