From 084c10784887d7c4d467416430626cf7eb333cb8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 23 Mar 2018 16:00:14 -0700 Subject: Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations. PiperOrigin-RevId: 190289664 --- .../base_api/api_def_ResourceScatterAdd.pbtxt | 2 +- .../base_api/api_def_ResourceScatterDiv.pbtxt | 43 +++++ .../base_api/api_def_ResourceScatterMax.pbtxt | 43 +++++ .../base_api/api_def_ResourceScatterMin.pbtxt | 43 +++++ .../base_api/api_def_ResourceScatterMul.pbtxt | 43 +++++ .../base_api/api_def_ResourceScatterSub.pbtxt | 43 +++++ .../core/api_def/base_api/api_def_ScatterAdd.pbtxt | 2 +- .../core/api_def/base_api/api_def_ScatterDiv.pbtxt | 2 +- .../core/api_def/base_api/api_def_ScatterMax.pbtxt | 60 ++++++ .../core/api_def/base_api/api_def_ScatterMin.pbtxt | 60 ++++++ .../core/api_def/base_api/api_def_ScatterMul.pbtxt | 2 +- .../core/api_def/base_api/api_def_ScatterSub.pbtxt | 2 +- .../api_def/base_api/api_def_ScatterUpdate.pbtxt | 2 +- .../python_api/api_def_ResourceScatterDiv.pbtxt | 4 + .../python_api/api_def_ResourceScatterMax.pbtxt | 4 + .../python_api/api_def_ResourceScatterMin.pbtxt | 4 + .../python_api/api_def_ResourceScatterMul.pbtxt | 4 + .../python_api/api_def_ResourceScatterSub.pbtxt | 4 + tensorflow/core/kernels/resource_variable_ops.cc | 81 +++++--- tensorflow/core/kernels/scatter_functor.cc | 27 ++- tensorflow/core/kernels/scatter_functor.h | 170 +++++++++++++++- tensorflow/core/kernels/scatter_functor_gpu.cu.cc | 9 +- tensorflow/core/kernels/scatter_functor_gpu.cu.h | 108 ++++++++--- tensorflow/core/kernels/scatter_op.cc | 126 ++++++++---- tensorflow/core/kernels/scatter_op_gpu.cu.cc | 9 +- tensorflow/core/kernels/scatter_op_test.cc | 26 ++- tensorflow/core/ops/resource_variable_ops.cc | 92 ++++++--- tensorflow/core/ops/state_ops.cc | 25 ++- tensorflow/docs_src/api_guides/python/state_ops.md | 2 + .../kernel_tests/resource_variable_ops_test.py | 215 +++++++++++++++++++++ tensorflow/python/kernel_tests/scatter_ops_test.py | 145 ++++++++++++-- tensorflow/python/ops/standard_ops.py | 2 + tensorflow/python/ops/state_ops.py | 2 + tensorflow/tools/api/golden/tensorflow.pbtxt | 8 + 34 files changed, 1261 insertions(+), 153 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt index 9e0de08267..4eb6eb4e4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt @@ -34,7 +34,7 @@ This operation computes Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt new file mode 100644 index 0000000000..47148f7b03 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterDiv" + in_arg { + name: "resource" + description: < + +
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt new file mode 100644 index 0000000000..71f06d9a43 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMax" + in_arg { + name: "resource" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt new file mode 100644 index 0000000000..08e40ee2a8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMin" + in_arg { + name: "resource" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt new file mode 100644 index 0000000000..5c63549d81 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMul" + in_arg { + name: "resource" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt new file mode 100644 index 0000000000..e71e60cbee --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterSub" + in_arg { + name: "resource" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt index 4b5201f025..9da9d09ea6 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt @@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt index 771cf0b591..8e99718c7e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions divide. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt new file mode 100644 index 0000000000..7b52dad4a1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMax" + in_arg { + name: "ref" + description: < + +
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt new file mode 100644 index 0000000000..721ac0ff35 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMin" + in_arg { + name: "ref" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt index a51f571b00..b9e293ba9e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions multiply. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt index c0d3a4a133..d12b3e68c2 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt @@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their (negated) contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt index c44dbbd233..4804908afc 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt @@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are duplicate entries in `indices`, the order at which the updates happen for each value is undefined. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt new file mode 100644 index 0000000000..56b5a46d10 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterDiv" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt new file mode 100644 index 0000000000..8119bcc6c6 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMax" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt new file mode 100644 index 0000000000..d874aef3fe --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMin" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt new file mode 100644 index 0000000000..365a37fa0d --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMul" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt new file mode 100644 index 0000000000..72dc5bf889 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterSub" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index aecad0185f..e134e476f6 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat(); auto params_flat = params->flat_outer_dims(); - int64 num_updates = updates.NumElements(); - OP_REQUIRES(c, num_updates % N == 0, - errors::InvalidArgument( - "shape of indices (", indices.shape().DebugString(), - ") is not compatible with the shape of updates (", - updates.shape().DebugString(), ")")); - auto updates_flat = updates.shaped({N, num_updates / N}); - - functor::ScatterFunctor functor; - const Index bad_i = functor(c, c->template eigen_device(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES(c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), - " = ", indices_flat(bad_i), " is not in [0, ", - params->dim_size(0), ")")); + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar(); + + functor::ScatterScalarFunctor functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } else { + int64 num_updates = updates.NumElements(); + OP_REQUIRES(c, num_updates % N == 0, + errors::InvalidArgument( + "shape of indices (", indices.shape().DebugString(), + ") is not compatible with the shape of updates (", + updates.shape().DebugString(), ")")); + auto updates_flat = updates.shaped({N, num_updates / N}); + + functor::ScatterFunctor functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } } } }; @@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -// TODO(apassos) add the other types here. -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ scatter_op::UpdateOp::ADD); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \ + scatter_op::UpdateOp::SUB); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \ + scatter_op::UpdateOp::MUL); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \ + scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \ + scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \ + scatter_op::UpdateOp::MAX); // Registers CPU kernels. -#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); #endif // GOOGLE_CUDA -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU #undef REGISTER_SCATTER_KERNEL #undef REGISTER_SCATTER_KERNEL_INDEX diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc index 7eba82899f..cf5408123f 100644 --- a/tensorflow/core/kernels/scatter_functor.cc +++ b/tensorflow/core/kernels/scatter_functor.cc @@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // Forward declarations of the functor specializations for GPU. -#define DECLARE_GPU_SPECS_OP(T, Index, op) \ - template <> \ - Index ScatterFunctor::operator()( \ - OpKernelContext* c, const GPUDevice& d, \ - typename TTypes::Matrix params, \ - typename TTypes::ConstMatrix updates, \ - typename TTypes::ConstFlat indices); \ - extern template struct ScatterFunctor; +#define DECLARE_GPU_SPECS_OP(T, Index, op) \ + template <> \ + Index ScatterFunctor::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes::Matrix params, \ + typename TTypes::ConstMatrix updates, \ + typename TTypes::ConstFlat indices); \ + extern template struct ScatterFunctor; \ + template <> \ + Index ScatterScalarFunctor::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes::Matrix params, \ + const typename TTypes::ConstScalar update, \ + typename TTypes::ConstFlat indices); \ + extern template struct ScatterScalarFunctor; #define DECLARE_GPU_SPECS_INDEX(T, Index) \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DECLARE_GPU_SPECS(T) \ DECLARE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index 079f15e101..52666645bf 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +35,7 @@ typedef Eigen::SyclDevice SYCLDevice; namespace scatter_op { -enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV }; +enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX }; namespace internal { @@ -45,6 +47,10 @@ struct Assign { static void Run(Params p, Update u) { p = u; } + template + static void RunScalar(Params p, Update u) { + p.setConstant(u); + } }; template <> struct Assign { @@ -52,6 +58,10 @@ struct Assign { static void Run(Params p, Update u) { p += u; } + template + static void RunScalar(Params p, Update u) { + p = p + u; + } }; template <> struct Assign { @@ -59,6 +69,10 @@ struct Assign { static void Run(Params p, Update u) { p -= u; } + template + static void RunScalar(Params p, Update u) { + p = p + static_cast(-u); + } }; template <> struct Assign { @@ -66,6 +80,10 @@ struct Assign { static void Run(Params p, Update u) { p *= u; } + template + static void RunScalar(Params p, Update u) { + p = p * u; + } }; template <> struct Assign { @@ -73,6 +91,34 @@ struct Assign { static void Run(Params p, Update u) { p /= u; } + template + static void RunScalar(Params p, Update u) { + p = p / u; + } +}; +template <> +struct Assign { + // This method requires that Params and Update are tensor types. + template + static void Run(Params p, Update u) { + p = p.cwiseMin(u); + } + // Same thing, but for Update being a scalar type. + template + static void RunScalar(Params p, Update u) { + p = p.cwiseMin(u); + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p = p.cwiseMax(u); + } + template + static void RunScalar(Params p, Update u) { + p = p.cwiseMax(u); + } }; #ifdef TENSORFLOW_USE_SYCL @@ -117,6 +163,22 @@ struct AssignSYCL { p.device(d) = p / u; } }; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) = p.cwiseMin(u); + } +}; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) = p.cwiseMax(u); + } +}; #endif // TENSORFLOW_USE_SYCL } // namespace internal @@ -241,6 +303,112 @@ struct ScatterFunctorSYCL { }; #endif // TENSORFLOW_USE_SYCL +template +struct ScatterScalarFunctor { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices); +}; + +template +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::AssignSYCL::RunScalar( + d, params.template chip<0>(index), update); + } + return -1; + } +}; +#endif // TENSORFLOW_USE_SYCL + +template +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const CPUDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +template +struct ScatterScalarFunctor + : ScatterScalarFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct ScatterScalarFunctorSYCL { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::Flat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::AssignSYCL::Run( + d, params.template chip<0>(index), update()); + } + return -1; + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc index 52972997cc..59911bf0d2 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc @@ -23,15 +23,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPECS_OP(T, Index, op) \ - template struct functor::ScatterFunctor; +#define DEFINE_GPU_SPECS_OP(T, Index, op) \ + template struct functor::ScatterFunctor; \ + template struct functor::ScatterScalarFunctor; #define DEFINE_GPU_SPECS_INDEX(T, Index) \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index be18658543..70809e4dcf 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -29,12 +29,53 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; +namespace scatter_op_gpu { + +template +struct ScatterOpKernelBody; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { *dest = src; } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); } +}; + template __global__ void ScatterOpCustomKernel(T* params, const T* updates, const Index* indices, Index first_dim_size, Index updates_size, Index indices_size) { Index update_block = updates_size / indices_size; + ScatterOpKernelBody body; CUDA_1D_KERNEL_LOOP(i, updates_size) { int indices_i = i / update_block; int updates_i = i; @@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates, continue; } int params_i = param_first_index * update_block + (i % update_block); - switch (op) { - case scatter_op::UpdateOp::ASSIGN: { - params[params_i] = ldg(updates + updates_i); - break; - } - case scatter_op::UpdateOp::ADD: { - CudaAtomicAdd(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::SUB: { - CudaAtomicSub(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::MUL: { - CudaAtomicMul(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::DIV: { - CudaAtomicDiv(params + params_i, ldg(updates + updates_i)); - break; - } + body(¶ms[params_i], ldg(updates + updates_i)); + } +} + +template +__global__ void ScatterScalarOpCustomKernel(T* params, const T* update, + const Index* indices, + Index first_dim_size, + Index indices_size, + Index synthesized_updates_size) { + Index update_block = synthesized_updates_size / indices_size; + ScatterOpKernelBody body; + CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) { + int indices_i = i / update_block; + int param_first_index = indices[indices_i]; + const T update_val = *update; + if (!(param_first_index >= 0 && param_first_index < first_dim_size)) { + // Ignore indices that are out of range. + continue; } + int params_i = param_first_index * update_block + (i % update_block); + body(¶ms[params_i], update_val); } } +} // namespace scatter_op_gpu + namespace functor { // Specialization for a GPU device. template @@ -84,7 +127,7 @@ struct ScatterFunctor { const Index indices_size = indices.size(); const Index updates_size = updates.size(); CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d); - ScatterOpCustomKernel + scatter_op_gpu::ScatterOpCustomKernel <<>>( params.data(), updates.data(), indices.data(), first_dim_size, updates_size, indices_size); @@ -92,6 +135,27 @@ struct ScatterFunctor { } }; +template +struct ScatterScalarFunctor { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // TODO(b/31801742): Implement indices range check. The hardest part is + // with returning a value after the range check, as we do not want to do + // device to host memcpy during a stream. + const Index first_dim_size = params.dimension(0); + const Index indices_size = indices.size(); + const Index synthesized_updates_size = indices_size * params.dimension(1); + CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); + scatter_op_gpu::ScatterScalarOpCustomKernel + <<>>( + params.data(), update.data(), indices.data(), first_dim_size, + indices_size, synthesized_updates_size); + return -1; + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 282165349f..0fbde764d5 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice; // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, const Tensor& indices) { + if (updates.dims() == 0) return true; if (updates.dims() != indices.dims() + params.dims() - 1) return false; for (int d = 0; d < indices.dims(); d++) { if (updates.dim_size(d) != indices.dim_size(d)) { @@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params, params.shape().DebugString())); OP_REQUIRES( c, ValidShapes(params, updates, indices), - errors::InvalidArgument( - "Must have updates.shape = indices.shape + params.shape[1:], got ", - "updates.shape ", updates.shape().DebugString(), ", indices.shape ", - indices.shape().DebugString(), ", params.shape ", - params.shape().DebugString())); + errors::InvalidArgument("Must have updates.shape = indices.shape + " + "params.shape[1:] or updates.shape = [], got ", + "updates.shape ", updates.shape().DebugString(), + ", indices.shape ", indices.shape().DebugString(), + ", params.shape ", params.shape().DebugString())); } template @@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat(); auto params_flat = params.flat_outer_dims(); - auto updates_flat = updates.shaped({N, updates.NumElements() / N}); - - functor::ScatterFunctor functor; - const Index bad_i = functor(c, c->template eigen_device(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape()) || + IsLegacyScalar(updates.shape())) { + const auto update = updates.scalar(); + functor::ScatterScalarFunctor functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped({N, updates.NumElements() / N}); + + functor::ScatterFunctor functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -195,16 +211,31 @@ class ScatterUpdateOp : public OpKernel { auto indices_flat = indices_host.flat(); auto params_flat = params.flat_outer_dims(); - auto updates_flat = updates.shaped({N, updates.NumElements() / N}); - - functor::ScatterFunctorSYCL functor; - const Index bad_i = functor(c, c->template eigen_device(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar(); + + functor::ScatterScalarFunctorSYCL functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped({N, updates.NumElements() / N}); + + functor::ScatterFunctorSYCL functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -221,54 +252,71 @@ class ScatterUpdateOp : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX); + #define REGISTER_SCATTER_UPDATE(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); // Registers CPU kernels. -#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); + +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); + +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA // Registers GPU kernels. #if TENSORFLOW_USE_SYCL -#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, SYCL); +#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \ + REGISTER_SCATTER_ARITHMETIC(type, SYCL); + +#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL); #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL); -#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_ARITHMETIC_SYCL +#undef REGISTER_SCATTER_MINMAX_SYCL #undef REGISTER_SCATTER_UPDATE_SYCL #endif // TENSORFLOW_USE_SYCL -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU -#undef REGISTER_SCATTER_ARITHEMTIC_GPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC_GPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU +#undef REGISTER_SCATTER_MINMAX_GPU #undef REGISTER_SCATTER_UPDATE #undef REGISTER_SCATTER_UPDATE_CPU #undef REGISTER_SCATTER_UPDATE_GPU diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc index 0b43704846..0df329310f 100644 --- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc @@ -24,15 +24,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; // Instantiates functor specializations for GPU. -#define DEFINE_GPU_SPECS_OP(T, Index, op) \ - template struct functor::ScatterFunctor; +#define DEFINE_GPU_SPECS_OP(T, Index, op) \ + template struct functor::ScatterFunctor; \ + template struct functor::ScatterScalarFunctor; #define DEFINE_GPU_SPECS_INDEX(T, Index) \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index 0b8645a2ae..5b3537b94c 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -185,7 +185,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], got ")) << s; } @@ -202,7 +202,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], got ")) << s; } @@ -219,7 +219,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], got ")) << s; } @@ -300,6 +300,20 @@ static void BM_ScatterDivInt64(int iters, int embedding_size) { BM_ScatterHelper(iters, embedding_size, "ScatterDiv"); } +static void BM_ScatterMinInt32(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterMin"); +} +static void BM_ScatterMinInt64(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterMin"); +} + +static void BM_ScatterMaxInt32(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterMax"); +} +static void BM_ScatterMaxInt64(int iters, int embedding_size) { + BM_ScatterHelper(iters, embedding_size, "ScatterMax"); +} + BENCHMARK(BM_ScatterUpdateInt32) ->Arg(1) ->Arg(10) @@ -332,5 +346,11 @@ BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 0d8cf78cc2..3d0a6c2157 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather") return Status::OK(); }); +namespace { + +Status ResourceScatterUpdateShape(InferenceContext* c) { + ShapeAndType handle_shape_and_type; + TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type)); + ShapeHandle var_shape = handle_shape_and_type.shape; + ShapeHandle indices_shape = c->input(1); + + ShapeHandle unused_updates_shape; + ShapeHandle concat; + ShapeHandle var_subshape; + TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); + TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); + TF_RETURN_IF_ERROR( + InferenceContext::Rank(c->input(2)) == 0 + ? Status::OK() + : c->Merge(c->input(2), concat, &unused_updates_shape)); + return Status::OK(); +} + +} // namespace + REGISTER_OP("ResourceScatterAdd") .Input("resource: resource") .Input("indices: Tindices") .Input("updates: dtype") .Attr("dtype: numbertype") .Attr("Tindices: {int32, int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeAndType handle_shape_and_type; - TF_RETURN_IF_ERROR( - ValidateVariableResourceHandle(c, &handle_shape_and_type)); - ShapeHandle var_shape = handle_shape_and_type.shape; - ShapeHandle indices_shape = c->input(1); + .SetShapeFn(ResourceScatterUpdateShape); - ShapeHandle unused_updates_shape; - ShapeHandle concat; - ShapeHandle var_subshape; - TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); - TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); - return Status::OK(); - }); +REGISTER_OP("ResourceScatterSub") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMul") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterDiv") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMin") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMax") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); REGISTER_OP("ResourceScatterUpdate") .Input("resource: resource") @@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate") .Input("updates: dtype") .Attr("dtype: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeAndType handle_shape_and_type; - TF_RETURN_IF_ERROR( - ValidateVariableResourceHandle(c, &handle_shape_and_type)); - ShapeHandle var_shape = handle_shape_and_type.shape; - ShapeHandle indices_shape = c->input(1); - - ShapeHandle unused_updates_shape; - ShapeHandle concat; - ShapeHandle var_subshape; - TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); - TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); - return Status::OK(); - }); + .SetShapeFn(ResourceScatterUpdateShape); REGISTER_OP("MutexV2") .Attr("container: string = ''") diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 7a524b60c0..664f52452e 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -122,7 +122,10 @@ Status ScatterUpdateShape(InferenceContext* c) { ShapeHandle var_subshape; TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); + TF_RETURN_IF_ERROR( + InferenceContext::Rank(c->input(2)) == 0 + ? Status::OK() + : c->Merge(c->input(2), concat, &unused_updates_shape)); c->set_output(0, var_shape); return Status::OK(); @@ -180,6 +183,26 @@ REGISTER_OP("ScatterDiv") .Attr("use_locking: bool = false") .SetShapeFn(ScatterUpdateShape); +REGISTER_OP("ScatterMin") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: {half, bfloat16, float, double, int32, int64}") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape); + +REGISTER_OP("ScatterMax") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: {half, bfloat16, float, double, int32, int64}") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape); + REGISTER_OP("ScatterNdUpdate") .Input("ref: Ref(T)") .Input("indices: Tindices") diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md index 0d612ee0c7..ec2d877386 100644 --- a/tensorflow/docs_src/api_guides/python/state_ops.md +++ b/tensorflow/docs_src/api_guides/python/state_ops.md @@ -83,6 +83,8 @@ automatically by the optimizers in most cases. * @{tf.scatter_sub} * @{tf.scatter_mul} * @{tf.scatter_div} +* @{tf.scatter_min} +* @{tf.scatter_max} * @{tf.scatter_nd_update} * @{tf.scatter_nd_add} * @{tf.scatter_nd_sub} diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 563eeff2a6..742564f9bf 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -185,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSub(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + [[2]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMul(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + [[5]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDiv(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMin(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMax(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterAddScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_add(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSubScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMulScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + 5, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDivScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMinScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMaxScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + def testScatterUpdateString(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.string, shape=[1, 1]) @@ -196,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + def testScatterUpdateStringScalar(self): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.string, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [["a"]], + dtype=dtypes.string))) + self.evaluate( + resource_variable_ops.resource_scatter_update(handle, [0], + constant_op.constant( + "b", + dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual( + compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + # TODO(alive): get this to work in Eager mode. def testGPU(self): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index 7cdf11d884..c70a4ffce7 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates): ref[indx] += updates[i] +def _NumpyAddScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] += update + + def _NumpySub(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] -= updates[i] +def _NumpySubScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] -= update + + def _NumpyMul(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] *= updates[i] +def _NumpyMulScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] *= update + + def _NumpyDiv(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] /= updates[i] +def _NumpyDivScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] /= update + + +def _NumpyMin(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], updates[i]) + + +def _NumpyMinScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], update) + + +def _NumpyMax(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], updates[i]) + + +def _NumpyMaxScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], update) + + def _NumpyUpdate(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] = updates[i] +def _NumpyUpdateScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = update + + _TF_OPS_TO_NUMPY = { state_ops.scatter_update: _NumpyUpdate, state_ops.scatter_add: _NumpyAdd, state_ops.scatter_sub: _NumpySub, state_ops.scatter_mul: _NumpyMul, state_ops.scatter_div: _NumpyDiv, + state_ops.scatter_min: _NumpyMin, + state_ops.scatter_max: _NumpyMax, +} + +_TF_OPS_TO_NUMPY_SCALAR = { + state_ops.scatter_update: _NumpyUpdateScalar, + state_ops.scatter_add: _NumpyAddScalar, + state_ops.scatter_sub: _NumpySubScalar, + state_ops.scatter_mul: _NumpyMulScalar, + state_ops.scatter_div: _NumpyDivScalar, + state_ops.scatter_min: _NumpyMinScalar, + state_ops.scatter_max: _NumpyMaxScalar, } class ScatterTest(test.TestCase): - def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False): + def _VariableRankTest(self, + tf_scatter, + vtype, + itype, + repeat_indices=False, + updates_are_scalar=False): np.random.seed(8) with self.test_session(use_gpu=True): for indices_shape in (), (2,), (3, 7), (3, 4, 7): @@ -89,8 +151,11 @@ class ScatterTest(test.TestCase): indices[np.random.randint(size // 2)]) np.random.shuffle(indices) indices = indices.reshape(indices_shape) - updates = _AsType( - np.random.randn(*(indices_shape + extra_shape)), vtype) + if updates_are_scalar: + updates = _AsType(np.random.randn(), vtype) + else: + updates = _AsType( + np.random.randn(*(indices_shape + extra_shape)), vtype) # Clips small values to avoid division by zero. def clip_small_values(x): @@ -101,7 +166,10 @@ class ScatterTest(test.TestCase): # Scatter via numpy new = old.copy() - np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] + if updates_are_scalar: + np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter] + else: + np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] np_scatter(new, indices, updates) # Scatter via tensorflow ref = variables.Variable(old) @@ -109,25 +177,35 @@ class ScatterTest(test.TestCase): tf_scatter(ref, indices, updates).eval() self.assertAllClose(ref.eval(), new) - def _VariableRankTests(self, tf_scatter, repeat_indices=False): + def _VariableRankTests(self, + tf_scatter, + repeat_indices=False, + updates_are_scalar=False): for vtype in (np.float32, np.float64): for itype in (np.int32, np.int64): - self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices) + self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices, + updates_are_scalar) def testVariableRankUpdate(self): - self._VariableRankTests(state_ops.scatter_update) + self._VariableRankTests(state_ops.scatter_update, False) def testVariableRankAdd(self): - self._VariableRankTests(state_ops.scatter_add) + self._VariableRankTests(state_ops.scatter_add, False) def testVariableRankSub(self): - self._VariableRankTests(state_ops.scatter_sub) + self._VariableRankTests(state_ops.scatter_sub, False) def testVariableRankMul(self): - self._VariableRankTests(state_ops.scatter_mul) + self._VariableRankTests(state_ops.scatter_mul, False) def testVariableRankDiv(self): - self._VariableRankTests(state_ops.scatter_div) + self._VariableRankTests(state_ops.scatter_div, False) + + def testVariableRankMin(self): + self._VariableRankTests(state_ops.scatter_min, False) + + def testVariableRankMax(self): + self._VariableRankTests(state_ops.scatter_max, False) def testRepeatIndicesAdd(self): self._VariableRankTests(state_ops.scatter_add, True) @@ -141,6 +219,51 @@ class ScatterTest(test.TestCase): def testRepeatIndicesDiv(self): self._VariableRankTests(state_ops.scatter_div, True) + def testRepeatIndicesMin(self): + self._VariableRankTests(state_ops.scatter_min, True) + + def testRepeatIndicesMax(self): + self._VariableRankTests(state_ops.scatter_max, True) + + def testVariableRankUpdateScalar(self): + self._VariableRankTests(state_ops.scatter_update, False, True) + + def testVariableRankAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, False, True) + + def testVariableRankSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, False, True) + + def testVariableRankMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, False, True) + + def testVariableRankDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, False, True) + + def testVariableRankMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, False, True) + + def testVariableRankMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, False, True) + + def testRepeatIndicesAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, True, True) + + def testRepeatIndicesSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, True, True) + + def testRepeatIndicesMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, True, True) + + def testRepeatIndicesDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, True, True) + + def testRepeatIndicesMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, True, True) + + def testRepeatIndicesMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, True, True) + def testBooleanScatterUpdate(self): if not test.is_gpu_available(): with self.test_session(use_gpu=False) as session: diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 230b7ef937..e90ff0746a 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add from tensorflow.python.ops.state_ops import scatter_div from tensorflow.python.ops.state_ops import scatter_mul from tensorflow.python.ops.state_ops import scatter_sub +from tensorflow.python.ops.state_ops import scatter_min +from tensorflow.python.ops.state_ops import scatter_max from tensorflow.python.ops.state_ops import scatter_update from tensorflow.python.ops.state_ops import scatter_nd_add from tensorflow.python.ops.state_ops import scatter_nd_sub diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index c3ad5831b4..01fc3182bc 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -63,6 +63,8 @@ @@scatter_nd_update @@scatter_sub @@scatter_update +@@scatter_min +@@scatter_max @@sparse_mask @@tables_initializer @@trainable_variables diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 55b82dd765..937044aece 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1688,6 +1688,14 @@ tf_module { name: "scatter_div" argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } + member_method { + name: "scatter_max" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_min" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } member_method { name: "scatter_mul" argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " -- cgit v1.2.3