aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/state_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/state_ops.py')
-rw-r--r--tensorflow/python/ops/state_ops.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 35fc1226ec..d556d11a1b 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -470,3 +470,57 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export("scatter_sub")
+def scatter_sub(ref, indices, updates, use_locking=False, name=None):
+ r"""Subtracts sparse updates to a variable reference.
+
+ ```python
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+ ```
+
+ This operation outputs `ref` after the update is done.
+ This makes it easier to chain operations that need to use the reset value.
+
+ Duplicate entries are handled correctly: if multiple `indices` reference
+ the same location, their (negated) contributions add.
+
+ Requires `updates.shape = indices.shape + ref.shape[1:]` or
+ `updates.shape = []`.
+
+ <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%"
+ src="https://www.tensorflow.org/images/ScatterSub.png" alt>
+ </div>
+
+ Args:
+ ref: A mutable `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
+ `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
+ `uint32`, `uint64`. Should be from a `Variable` node.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into the first dimension of `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to subtract from `ref`.
+ use_locking: An optional `bool`. Defaults to `False`.
+ If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ A mutable `Tensor`. Has the same type as `ref`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_sub(ref, indices, updates,
+ use_locking=use_locking, name=name)
+ return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))