aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-23 16:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 04:21:07 -0700
commit084c10784887d7c4d467416430626cf7eb333cb8 (patch)
tree69727c1d14ddbb97fc74b24f69ce2381125ee6c9
parent95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff)
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt60
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt60
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt4
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc81
-rw-r--r--tensorflow/core/kernels/scatter_functor.cc27
-rw-r--r--tensorflow/core/kernels/scatter_functor.h170
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.cc9
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.h108
-rw-r--r--tensorflow/core/kernels/scatter_op.cc126
-rw-r--r--tensorflow/core/kernels/scatter_op_gpu.cu.cc9
-rw-r--r--tensorflow/core/kernels/scatter_op_test.cc26
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc92
-rw-r--r--tensorflow/core/ops/state_ops.cc25
-rw-r--r--tensorflow/docs_src/api_guides/python/state_ops.md2
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py215
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py145
-rw-r--r--tensorflow/python/ops/standard_ops.py2
-rw-r--r--tensorflow/python/ops/state_ops.py2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt8
34 files changed, 1261 insertions, 153 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
index 9e0de08267..4eb6eb4e4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
@@ -34,7 +34,7 @@ This operation computes
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644
index 0000000000..47148f7b03
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterDiv"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to add to `ref`.
+END
+ }
+ summary: "Divides sparse updates into the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] /= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] /= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644
index 0000000000..71f06d9a43
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMax"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to add to `ref`.
+END
+ }
+ summary: "Reduces sparse updates into the variable referenced by `resource` using the `max` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644
index 0000000000..08e40ee2a8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMin"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to add to `ref`.
+END
+ }
+ summary: "Reduces sparse updates into the variable referenced by `resource` using the `min` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644
index 0000000000..5c63549d81
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMul"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to add to `ref`.
+END
+ }
+ summary: "Multiplies sparse updates into the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] *= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] *= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644
index 0000000000..e71e60cbee
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterSub"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to add to `ref`.
+END
+ }
+ summary: "Subtracts sparse updates from the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
index 4b5201f025..9da9d09ea6 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
index 771cf0b591..8e99718c7e 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions divide.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt
new file mode 100644
index 0000000000..7b52dad4a1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt
@@ -0,0 +1,60 @@
+op {
+ graph_op_name: "ScatterMax"
+ in_arg {
+ name: "ref"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ out_arg {
+ name: "output_ref"
+ description: <<END
+= Same as `ref`. Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+ }
+ summary: "Reduces sparse updates into a variable reference using the `max` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt
new file mode 100644
index 0000000000..721ac0ff35
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt
@@ -0,0 +1,60 @@
+op {
+ graph_op_name: "ScatterMin"
+ in_arg {
+ name: "ref"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ out_arg {
+ name: "output_ref"
+ description: <<END
+= Same as `ref`. Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+ }
+ summary: "Reduces sparse updates into a variable reference using the `min` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
index a51f571b00..b9e293ba9e 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions multiply.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
index c0d3a4a133..d12b3e68c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their (negated) contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterSub.png" alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
index c44dbbd233..4804908afc 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
@@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are
duplicate entries in `indices`, the order at which the updates happen
for each value is undefined.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644
index 0000000000..56b5a46d10
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterDiv"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644
index 0000000000..8119bcc6c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMax"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644
index 0000000000..d874aef3fe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMin"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644
index 0000000000..365a37fa0d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMul"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644
index 0000000000..72dc5bf889
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index aecad0185f..e134e476f6 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
- int64 num_updates = updates.NumElements();
- OP_REQUIRES(c, num_updates % N == 0,
- errors::InvalidArgument(
- "shape of indices (", indices.shape().DebugString(),
- ") is not compatible with the shape of updates (",
- updates.shape().DebugString(), ")"));
- auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i),
- " = ", indices_flat(bad_i), " is not in [0, ",
- params->dim_size(0), ")"));
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ } else {
+ int64 num_updates = updates.NumElements();
+ OP_REQUIRES(c, num_updates % N == 0,
+ errors::InvalidArgument(
+ "shape of indices (", indices.shape().DebugString(),
+ ") is not compatible with the shape of updates (",
+ updates.shape().DebugString(), ")"));
+ auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ }
}
}
};
@@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-// TODO(apassos) add the other types here.
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
+ scatter_op::UpdateOp::SUB); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
+ scatter_op::UpdateOp::MUL); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
+ scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
+ scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
+ scatter_op::UpdateOp::MAX);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
#endif // GOOGLE_CUDA
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX
diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc
index 7eba82899f..cf5408123f 100644
--- a/tensorflow/core/kernels/scatter_functor.cc
+++ b/tensorflow/core/kernels/scatter_functor.cc
@@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPECS_OP(T, Index, op) \
- template <> \
- Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
- OpKernelContext* c, const GPUDevice& d, \
- typename TTypes<T>::Matrix params, \
- typename TTypes<T>::ConstMatrix updates, \
- typename TTypes<Index>::ConstFlat indices); \
- extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+#define DECLARE_GPU_SPECS_OP(T, Index, op) \
+ template <> \
+ Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ typename TTypes<T>::ConstMatrix updates, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterFunctor<GPUDevice, T, Index, op>; \
+ template <> \
+ Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ const typename TTypes<T>::ConstScalar update, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index 079f15e101..52666645bf 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <type_traits>
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/types.h"
@@ -33,7 +35,7 @@ typedef Eigen::SyclDevice SYCLDevice;
namespace scatter_op {
-enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };
namespace internal {
@@ -45,6 +47,10 @@ struct Assign<scatter_op::UpdateOp::ASSIGN> {
static void Run(Params p, Update u) {
p = u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p.setConstant(u);
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
@@ -52,6 +58,10 @@ struct Assign<scatter_op::UpdateOp::ADD> {
static void Run(Params p, Update u) {
p += u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p + u;
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
@@ -59,6 +69,10 @@ struct Assign<scatter_op::UpdateOp::SUB> {
static void Run(Params p, Update u) {
p -= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p + static_cast<Update>(-u);
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
@@ -66,6 +80,10 @@ struct Assign<scatter_op::UpdateOp::MUL> {
static void Run(Params p, Update u) {
p *= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p * u;
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
@@ -73,6 +91,34 @@ struct Assign<scatter_op::UpdateOp::DIV> {
static void Run(Params p, Update u) {
p /= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p / u;
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MIN> {
+ // This method requires that Params and Update are tensor types.
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p = p.cwiseMin(u);
+ }
+ // Same thing, but for Update being a scalar type.
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p.cwiseMin(u);
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MAX> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p = p.cwiseMax(u);
+ }
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p.cwiseMax(u);
+ }
};
#ifdef TENSORFLOW_USE_SYCL
@@ -117,6 +163,22 @@ struct AssignSYCL<scatter_op::UpdateOp::DIV> {
p.device(d) = p / u;
}
};
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MIN> {
+ template <typename Device, typename Params, typename Update>
+ static void Run(Device d, Params p, Update u) {
+ p.device(d) = p.cwiseMin(u);
+ }
+};
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MAX> {
+ template <typename Device, typename Params, typename Update>
+ static void Run(Device d, Params p, Update u) {
+ p.device(d) = p.cwiseMax(u);
+ }
+};
#endif // TENSORFLOW_USE_SYCL
} // namespace internal
@@ -241,6 +303,112 @@ struct ScatterFunctorSYCL {
};
#endif // TENSORFLOW_USE_SYCL
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices);
+};
+
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::Assign<op>::RunScalar(
+ params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
+ Index operator()(OpKernelContext* c, const SYCLDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::AssignSYCL<op>::RunScalar(
+ d, params.template chip<0>(index), update);
+ }
+ return -1;
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename T, typename Index>
+struct ScatterScalarFunctorBase<CPUDevice, T, Index,
+ scatter_op::UpdateOp::ASSIGN> {
+ Index operator()(OpKernelContext* c, const CPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
+ params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<CPUDevice, T, Index, op>
+ : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorSYCL {
+ Index operator()(OpKernelContext* c, const SYCLDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::Flat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::AssignSYCL<op>::Run(
+ d, params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
index 52972997cc..59911bf0d2 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
@@ -23,15 +23,18 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
- template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+ template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index be18658543..70809e4dcf 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -29,12 +29,53 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+namespace scatter_op_gpu {
+
+template <typename T, scatter_op::UpdateOp op>
+struct ScatterOpKernelBody;
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> {
+ __device__ void operator()(T* dest, T src) const { *dest = src; }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); }
+};
+
template <typename T, typename Index, scatter_op::UpdateOp op>
__global__ void ScatterOpCustomKernel(T* params, const T* updates,
const Index* indices,
Index first_dim_size, Index updates_size,
Index indices_size) {
Index update_block = updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
CUDA_1D_KERNEL_LOOP(i, updates_size) {
int indices_i = i / update_block;
int updates_i = i;
@@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
continue;
}
int params_i = param_first_index * update_block + (i % update_block);
- switch (op) {
- case scatter_op::UpdateOp::ASSIGN: {
- params[params_i] = ldg(updates + updates_i);
- break;
- }
- case scatter_op::UpdateOp::ADD: {
- CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::SUB: {
- CudaAtomicSub(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::MUL: {
- CudaAtomicMul(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::DIV: {
- CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
- break;
- }
+ body(&params[params_i], ldg(updates + updates_i));
+ }
+}
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
+ const Index* indices,
+ Index first_dim_size,
+ Index indices_size,
+ Index synthesized_updates_size) {
+ Index update_block = synthesized_updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
+ CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) {
+ int indices_i = i / update_block;
+ int param_first_index = indices[indices_i];
+ const T update_val = *update;
+ if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+ // Ignore indices that are out of range.
+ continue;
}
+ int params_i = param_first_index * update_block + (i % update_block);
+ body(&params[params_i], update_val);
}
}
+} // namespace scatter_op_gpu
+
namespace functor {
// Specialization for a GPU device.
template <typename T, typename Index, scatter_op::UpdateOp op>
@@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
const Index indices_size = indices.size();
const Index updates_size = updates.size();
CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
- ScatterOpCustomKernel<T, Index, op>
+ scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
params.data(), updates.data(), indices.data(), first_dim_size,
updates_size, indices_size);
@@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
}
};
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
+ Index operator()(OpKernelContext* c, const GPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // TODO(b/31801742): Implement indices range check. The hardest part is
+ // with returning a value after the range check, as we do not want to do
+ // device to host memcpy during a stream.
+ const Index first_dim_size = params.dimension(0);
+ const Index indices_size = indices.size();
+ const Index synthesized_updates_size = indices_size * params.dimension(1);
+ CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d);
+ scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ params.data(), update.data(), indices.data(), first_dim_size,
+ indices_size, synthesized_updates_size);
+ return -1;
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 282165349f..0fbde764d5 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice;
// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
const Tensor& indices) {
+ if (updates.dims() == 0) return true;
if (updates.dims() != indices.dims() + params.dims() - 1) return false;
for (int d = 0; d < indices.dims(); d++) {
if (updates.dim_size(d) != indices.dim_size(d)) {
@@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
params.shape().DebugString()));
OP_REQUIRES(
c, ValidShapes(params, updates, indices),
- errors::InvalidArgument(
- "Must have updates.shape = indices.shape + params.shape[1:], got ",
- "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
- indices.shape().DebugString(), ", params.shape ",
- params.shape().DebugString()));
+ errors::InvalidArgument("Must have updates.shape = indices.shape + "
+ "params.shape[1:] or updates.shape = [], got ",
+ "updates.shape ", updates.shape().DebugString(),
+ ", indices.shape ", indices.shape().DebugString(),
+ ", params.shape ", params.shape().DebugString()));
}
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
@@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape()) ||
+ IsLegacyScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
auto indices_flat = indices_host.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctorSYCL<T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);
+
#define REGISTER_SCATTER_UPDATE(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
// Registers GPU kernels.
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, SYCL);
+
+#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);
#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);
-#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_ARITHMETIC_SYCL
+#undef REGISTER_SCATTER_MINMAX_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif // TENSORFLOW_USE_SYCL
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
-#undef REGISTER_SCATTER_ARITHEMTIC_GPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC_GPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
+#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
index 0b43704846..0df329310f 100644
--- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -24,15 +24,18 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Instantiates functor specializations for GPU.
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
- template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+ template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index 0b8645a2ae..5b3537b94c 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -185,7 +185,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -202,7 +202,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -219,7 +219,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -300,6 +300,20 @@ static void BM_ScatterDivInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
}
+static void BM_ScatterMinInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin");
+}
+static void BM_ScatterMinInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin");
+}
+
+static void BM_ScatterMaxInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax");
+}
+static void BM_ScatterMaxInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax");
+}
+
BENCHMARK(BM_ScatterUpdateInt32)
->Arg(1)
->Arg(10)
@@ -332,5 +346,11 @@ BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 0d8cf78cc2..3d0a6c2157 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather")
return Status::OK();
});
+namespace {
+
+Status ResourceScatterUpdateShape(InferenceContext* c) {
+ ShapeAndType handle_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
+ ShapeHandle var_shape = handle_shape_and_type.shape;
+ ShapeHandle indices_shape = c->input(1);
+
+ ShapeHandle unused_updates_shape;
+ ShapeHandle concat;
+ ShapeHandle var_subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
+ TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
+ TF_RETURN_IF_ERROR(
+ InferenceContext::Rank(c->input(2)) == 0
+ ? Status::OK()
+ : c->Merge(c->input(2), concat, &unused_updates_shape));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("ResourceScatterAdd")
.Input("resource: resource")
.Input("indices: Tindices")
.Input("updates: dtype")
.Attr("dtype: numbertype")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
+ .SetShapeFn(ResourceScatterUpdateShape);
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+REGISTER_OP("ResourceScatterSub")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMul")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterDiv")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMin")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMax")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("ResourceScatterUpdate")
.Input("resource: resource")
@@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate")
.Input("updates: dtype")
.Attr("dtype: type")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
-
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("MutexV2")
.Attr("container: string = ''")
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 7a524b60c0..664f52452e 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -122,7 +122,10 @@ Status ScatterUpdateShape(InferenceContext* c) {
ShapeHandle var_subshape;
TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
+ TF_RETURN_IF_ERROR(
+ InferenceContext::Rank(c->input(2)) == 0
+ ? Status::OK()
+ : c->Merge(c->input(2), concat, &unused_updates_shape));
c->set_output(0, var_shape);
return Status::OK();
@@ -180,6 +183,26 @@ REGISTER_OP("ScatterDiv")
.Attr("use_locking: bool = false")
.SetShapeFn(ScatterUpdateShape);
+REGISTER_OP("ScatterMin")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape);
+
+REGISTER_OP("ScatterMax")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape);
+
REGISTER_OP("ScatterNdUpdate")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md
index 0d612ee0c7..ec2d877386 100644
--- a/tensorflow/docs_src/api_guides/python/state_ops.md
+++ b/tensorflow/docs_src/api_guides/python/state_ops.md
@@ -83,6 +83,8 @@ automatically by the optimizers in most cases.
* @{tf.scatter_sub}
* @{tf.scatter_mul}
* @{tf.scatter_div}
+* @{tf.scatter_min}
+* @{tf.scatter_max}
* @{tf.scatter_nd_update}
* @{tf.scatter_nd_add}
* @{tf.scatter_nd_sub}
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 563eeff2a6..742564f9bf 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -185,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterSub(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(handle, [0],
+ constant_op.constant(
+ [[2]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMul(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(handle, [0],
+ constant_op.constant(
+ [[5]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterDiv(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMin(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMax(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterAddScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_add(handle, [0],
+ constant_op.constant(
+ 2,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterSubScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(handle, [0],
+ constant_op.constant(
+ 2,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMulScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(handle, [0],
+ constant_op.constant(
+ 5,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterDivScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMinScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMaxScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
+
def testScatterUpdateString(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.string, shape=[1, 1])
@@ -196,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
compat.as_bytes("b"))
+ def testScatterUpdateStringScalar(self):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.string, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [["a"]],
+ dtype=dtypes.string)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_update(handle, [0],
+ constant_op.constant(
+ "b",
+ dtype=dtypes.string)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
+ self.assertEqual(
+ compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b"))
+
# TODO(alive): get this to work in Eager mode.
def testGPU(self):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 7cdf11d884..c70a4ffce7 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates):
ref[indx] += updates[i]
+def _NumpyAddScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] += update
+
+
def _NumpySub(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] -= updates[i]
+def _NumpySubScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] -= update
+
+
def _NumpyMul(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] *= updates[i]
+def _NumpyMulScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] *= update
+
+
def _NumpyDiv(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] /= updates[i]
+def _NumpyDivScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] /= update
+
+
+def _NumpyMin(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] = np.minimum(ref[indx], updates[i])
+
+
+def _NumpyMinScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = np.minimum(ref[indx], update)
+
+
+def _NumpyMax(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] = np.maximum(ref[indx], updates[i])
+
+
+def _NumpyMaxScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = np.maximum(ref[indx], update)
+
+
def _NumpyUpdate(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] = updates[i]
+def _NumpyUpdateScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = update
+
+
_TF_OPS_TO_NUMPY = {
state_ops.scatter_update: _NumpyUpdate,
state_ops.scatter_add: _NumpyAdd,
state_ops.scatter_sub: _NumpySub,
state_ops.scatter_mul: _NumpyMul,
state_ops.scatter_div: _NumpyDiv,
+ state_ops.scatter_min: _NumpyMin,
+ state_ops.scatter_max: _NumpyMax,
+}
+
+_TF_OPS_TO_NUMPY_SCALAR = {
+ state_ops.scatter_update: _NumpyUpdateScalar,
+ state_ops.scatter_add: _NumpyAddScalar,
+ state_ops.scatter_sub: _NumpySubScalar,
+ state_ops.scatter_mul: _NumpyMulScalar,
+ state_ops.scatter_div: _NumpyDivScalar,
+ state_ops.scatter_min: _NumpyMinScalar,
+ state_ops.scatter_max: _NumpyMaxScalar,
}
class ScatterTest(test.TestCase):
- def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False):
+ def _VariableRankTest(self,
+ tf_scatter,
+ vtype,
+ itype,
+ repeat_indices=False,
+ updates_are_scalar=False):
np.random.seed(8)
with self.test_session(use_gpu=True):
for indices_shape in (), (2,), (3, 7), (3, 4, 7):
@@ -89,8 +151,11 @@ class ScatterTest(test.TestCase):
indices[np.random.randint(size // 2)])
np.random.shuffle(indices)
indices = indices.reshape(indices_shape)
- updates = _AsType(
- np.random.randn(*(indices_shape + extra_shape)), vtype)
+ if updates_are_scalar:
+ updates = _AsType(np.random.randn(), vtype)
+ else:
+ updates = _AsType(
+ np.random.randn(*(indices_shape + extra_shape)), vtype)
# Clips small values to avoid division by zero.
def clip_small_values(x):
@@ -101,7 +166,10 @@ class ScatterTest(test.TestCase):
# Scatter via numpy
new = old.copy()
- np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
+ if updates_are_scalar:
+ np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter]
+ else:
+ np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
ref = variables.Variable(old)
@@ -109,25 +177,35 @@ class ScatterTest(test.TestCase):
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
- def _VariableRankTests(self, tf_scatter, repeat_indices=False):
+ def _VariableRankTests(self,
+ tf_scatter,
+ repeat_indices=False,
+ updates_are_scalar=False):
for vtype in (np.float32, np.float64):
for itype in (np.int32, np.int64):
- self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices)
+ self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
+ updates_are_scalar)
def testVariableRankUpdate(self):
- self._VariableRankTests(state_ops.scatter_update)
+ self._VariableRankTests(state_ops.scatter_update, False)
def testVariableRankAdd(self):
- self._VariableRankTests(state_ops.scatter_add)
+ self._VariableRankTests(state_ops.scatter_add, False)
def testVariableRankSub(self):
- self._VariableRankTests(state_ops.scatter_sub)
+ self._VariableRankTests(state_ops.scatter_sub, False)
def testVariableRankMul(self):
- self._VariableRankTests(state_ops.scatter_mul)
+ self._VariableRankTests(state_ops.scatter_mul, False)
def testVariableRankDiv(self):
- self._VariableRankTests(state_ops.scatter_div)
+ self._VariableRankTests(state_ops.scatter_div, False)
+
+ def testVariableRankMin(self):
+ self._VariableRankTests(state_ops.scatter_min, False)
+
+ def testVariableRankMax(self):
+ self._VariableRankTests(state_ops.scatter_max, False)
def testRepeatIndicesAdd(self):
self._VariableRankTests(state_ops.scatter_add, True)
@@ -141,6 +219,51 @@ class ScatterTest(test.TestCase):
def testRepeatIndicesDiv(self):
self._VariableRankTests(state_ops.scatter_div, True)
+ def testRepeatIndicesMin(self):
+ self._VariableRankTests(state_ops.scatter_min, True)
+
+ def testRepeatIndicesMax(self):
+ self._VariableRankTests(state_ops.scatter_max, True)
+
+ def testVariableRankUpdateScalar(self):
+ self._VariableRankTests(state_ops.scatter_update, False, True)
+
+ def testVariableRankAddScalar(self):
+ self._VariableRankTests(state_ops.scatter_add, False, True)
+
+ def testVariableRankSubScalar(self):
+ self._VariableRankTests(state_ops.scatter_sub, False, True)
+
+ def testVariableRankMulScalar(self):
+ self._VariableRankTests(state_ops.scatter_mul, False, True)
+
+ def testVariableRankDivScalar(self):
+ self._VariableRankTests(state_ops.scatter_div, False, True)
+
+ def testVariableRankMinScalar(self):
+ self._VariableRankTests(state_ops.scatter_min, False, True)
+
+ def testVariableRankMaxScalar(self):
+ self._VariableRankTests(state_ops.scatter_max, False, True)
+
+ def testRepeatIndicesAddScalar(self):
+ self._VariableRankTests(state_ops.scatter_add, True, True)
+
+ def testRepeatIndicesSubScalar(self):
+ self._VariableRankTests(state_ops.scatter_sub, True, True)
+
+ def testRepeatIndicesMulScalar(self):
+ self._VariableRankTests(state_ops.scatter_mul, True, True)
+
+ def testRepeatIndicesDivScalar(self):
+ self._VariableRankTests(state_ops.scatter_div, True, True)
+
+ def testRepeatIndicesMinScalar(self):
+ self._VariableRankTests(state_ops.scatter_min, True, True)
+
+ def testRepeatIndicesMaxScalar(self):
+ self._VariableRankTests(state_ops.scatter_max, True, True)
+
def testBooleanScatterUpdate(self):
if not test.is_gpu_available():
with self.test_session(use_gpu=False) as session:
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 230b7ef937..e90ff0746a 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add
from tensorflow.python.ops.state_ops import scatter_div
from tensorflow.python.ops.state_ops import scatter_mul
from tensorflow.python.ops.state_ops import scatter_sub
+from tensorflow.python.ops.state_ops import scatter_min
+from tensorflow.python.ops.state_ops import scatter_max
from tensorflow.python.ops.state_ops import scatter_update
from tensorflow.python.ops.state_ops import scatter_nd_add
from tensorflow.python.ops.state_ops import scatter_nd_sub
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index c3ad5831b4..01fc3182bc 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -63,6 +63,8 @@
@@scatter_nd_update
@@scatter_sub
@@scatter_update
+@@scatter_min
+@@scatter_max
@@sparse_mask
@@tables_initializer
@@trainable_variables
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 55b82dd765..937044aece 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1689,6 +1689,14 @@ tf_module {
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "scatter_max"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_min"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
name: "scatter_mul"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}