diff options
Diffstat (limited to 'tensorflow/python/ops/state_ops.py')
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 35fc1226ec..d556d11a1b 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -470,3 +470,57 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), name=name)) + + +@tf_export("scatter_sub") +def scatter_sub(ref, indices, updates, use_locking=False, name=None): + r"""Subtracts sparse updates to a variable reference. + + ```python + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + ``` + + This operation outputs `ref` after the update is done. + This makes it easier to chain operations that need to use the reset value. + + Duplicate entries are handled correctly: if multiple `indices` reference + the same location, their (negated) contributions add. + + Requires `updates.shape = indices.shape + ref.shape[1:]` or + `updates.shape = []`. + + <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" + src="https://www.tensorflow.org/images/ScatterSub.png" alt> + </div> + + Args: + ref: A mutable `Tensor`. Must be one of the following types: `float32`, + `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, + `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, + `uint32`, `uint64`. Should be from a `Variable` node. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A tensor of indices into the first dimension of `ref`. + updates: A `Tensor`. Must have the same type as `ref`. + A tensor of updated values to subtract from `ref`. + use_locking: An optional `bool`. Defaults to `False`. + If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + A mutable `Tensor`. Has the same type as `ref`. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_sub(ref, indices, updates, + use_locking=use_locking, name=name) + return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) |