aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/core/kernels/BUILD21
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc547
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.h52
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_test.cc320
-rw-r--r--tensorflow/core/ops/array_ops.cc77
-rw-r--r--tensorflow/core/ops/state_ops.cc235
-rw-r--r--tensorflow/python/kernel_tests/BUILD7
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py72
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py383
-rw-r--r--tensorflow/python/ops/array_grad.py15
-rw-r--r--tensorflow/python/ops/array_ops.py48
-rw-r--r--tensorflow/python/ops/standard_ops.py5
-rw-r--r--tensorflow/python/ops/state_grad.py13
-rw-r--r--tensorflow/python/ops/state_ops.py36
15 files changed, 1821 insertions, 11 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index ed5d6539b3..ea1612201e 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -184,3 +184,4 @@ tensorflow/core/ops/control_flow_ops.cc
tensorflow/core/ops/candidate_sampling_ops.cc
tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/array_grad.cc
+tensorflow/core/ops/scatter_nd_op.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 34954f0066..0a7a42fde6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2025,11 +2025,13 @@ tf_kernel_libraries(
"count_up_to_op",
"dense_update_ops",
"scatter_op",
+ "scatter_nd_op",
"variable_ops",
],
deps = [
":assign_op",
":bounds_check",
+ ":fill_functor",
":scatter_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -2043,6 +2045,7 @@ tf_cc_test(
size = "small",
srcs = ["scatter_op_test.cc"],
deps = [
+ ":fill_functor",
":ops_testutil",
":ops_util",
":scatter_op",
@@ -2055,6 +2058,23 @@ tf_cc_test(
],
)
+# C++ test target for the ScatterNd kernels; the sources are in
+# scatter_nd_op_test.cc and link against the :scatter_nd_op kernel library.
+tf_cc_test(
+ name = "scatter_nd_op_test",
+ size = "small",
+ srcs = ["scatter_nd_op_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":scatter_nd_op",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_kernel_libraries(
name = "string",
prefixes = [
@@ -2497,6 +2517,7 @@ filegroup(
"debug_ops.*",
# Ops excluded because they do not build correctly for Android.
# See b/29213790
+ "scatter_nd_op.*",
"sparse_matmul_op.*",
],
),
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
new file mode 100644
index 0000000000..61098c7802
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -0,0 +1,547 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/state_ops.cc.
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/scatter_nd_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// Checks that `updates` has the shape
+//   indices.shape[:-1] + params_shape[indices.shape[-1]:]
+// i.e. one slice of `params` for every index tuple stored in `indices`.
+//
+// A rank-1 `indices` tensor is treated as a batch of single-element index
+// tuples, which mirrors how PrepareAndValidateInputs derives indices_nd.
+// The rank of `updates` is verified before any of its dimensions are read, so
+// this never indexes past the last dimension of `updates`.
+static bool ValidUpdateShape(const TensorShape& params_shape,
+                             const Tensor& indices, const Tensor& updates) {
+  const int indices_rank = indices.dims();
+  if (indices_rank < 1) return false;
+
+  // Length of each index tuple: the innermost dimension of `indices`.
+  const int64 indices_nd =
+      (indices_rank > 1) ? indices.dim_size(indices_rank - 1) : 1;
+  // Number of leading "batch" dimensions shared by `indices` and `updates`.
+  const int batch_dims = (indices_rank > 1) ? indices_rank - 1 : 1;
+  // Number of trailing `params` dimensions covered by each update slice.
+  const int64 params_rank = params_shape.dims();
+  const int64 slice_dims =
+      (params_rank > indices_nd) ? params_rank - indices_nd : 0;
+
+  if (updates.dims() != batch_dims + slice_dims) return false;
+
+  // The batch dimensions must agree with those of `indices`.
+  for (int d = 0; d < batch_dims; ++d) {
+    if (updates.dim_size(d) != indices.dim_size(d)) return false;
+  }
+  // The remaining dimensions must agree with the un-indexed part of `params`.
+  for (int64 d = indices_nd; d < params_rank; ++d) {
+    if (updates.dim_size(d - indices_nd + batch_dims) !=
+        params_shape.dim_size(d)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Validates the inputs of a ScatterNd-style op and derives the quantities the
+// kernels need to work on flattened tensors.
+//
+// On failure the error is recorded on `c` (via OP_REQUIRES) and the function
+// returns early; callers must check c->status() before using the outputs.
+//
+// Outputs:
+//   *indices_nd  - length of each index tuple (innermost dim of `indices`, or
+//                  1 when `indices` is a vector).
+//   *num_updates - number of index tuples contained in `indices`.
+//   *slice_size  - number of elements in the slice of `params` addressed by
+//                  one index tuple.
+template <typename Index>
+static void PrepareAndValidateInputs(OpKernelContext* c,
+                                     const TensorShape& params_shape,
+                                     const Tensor& indices,
+                                     const Tensor& updates, int64* indices_nd,
+                                     Index* num_updates, Index* slice_size) {
+  const TensorShape& indices_shape(indices.shape());
+  const TensorShape& updates_shape(updates.shape());
+
+  OP_REQUIRES(
+      c, TensorShapeUtils::IsVectorOrHigher(params_shape),
+      errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ",
+                              params_shape.DebugString()));
+
+  // dim_size(0) is read on both tensors below, so reject scalars up front.
+  OP_REQUIRES(c, indices_shape.dims() >= 1,
+              errors::InvalidArgument("Indices must be at least 1-D, ",
+                                      "got shape: ",
+                                      indices_shape.DebugString()));
+  OP_REQUIRES(c, updates_shape.dims() >= 1,
+              errors::InvalidArgument("Updates must be at least 1-D, ",
+                                      "got shape: ",
+                                      updates_shape.DebugString()));
+
+  // An empty output can only be paired with empty indices and updates. The
+  // comparison has to be strict: num_elements() is never negative, so ">= 0"
+  // would make this check a no-op.
+  OP_REQUIRES(c, params_shape.num_elements() > 0 ||
+                     (indices.NumElements() == 0 && updates.NumElements() == 0),
+              errors::InvalidArgument(
+                  "Indices and updates specified for empty output", " shape"));
+
+  OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0),
+              errors::InvalidArgument(
+                  "The outermost dimension of updates and indices ",
+                  "must match. Got indices.shape ", indices_shape.DebugString(),
+                  ", updates.shape ", updates_shape.DebugString()));
+  OP_REQUIRES(
+      c, ValidUpdateShape(params_shape, indices, updates),
+      errors::InvalidArgument(
+          "Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ",
+          "got updates.shape ", updates_shape.DebugString(), ", indices.shape ",
+          indices_shape.DebugString(), ", params_shape ",
+          params_shape.DebugString()));
+
+  // Check that we have enough index space.
+  const int64 N_big = indices.NumElements();
+  OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
+              errors::InvalidArgument(
+                  "indices has too many elements for ",
+                  DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
+                  N_big, " > ", std::numeric_limits<Index>::max()));
+  OP_REQUIRES(
+      c, params_shape.dim_size(0) <= std::numeric_limits<Index>::max(),
+      errors::InvalidArgument("params_shape[0] too large for ",
+                              DataTypeString(DataTypeToEnum<Index>::v()),
+                              " indexing: ", params_shape.dim_size(0), " > ",
+                              std::numeric_limits<Index>::max()));
+
+  // Length of each index tuple: the innermost dimension of `indices`, or 1
+  // when `indices` is a plain vector.
+  *indices_nd = 1;
+  if (indices_shape.dims() > 1) {
+    *indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
+  }
+
+  // Calculate the number of elements that make up each slice of our updated
+  // tensor. This allows us to work with flattened tensors and copy over whole
+  // slices at a time.
+  const Index total_nd = params_shape.dims();
+
+  int64 slice_size_big = 1;
+  for (int64 i = *indices_nd; i < total_nd; ++i) {
+    slice_size_big *= params_shape.dim_size(i);
+  }
+
+  OP_REQUIRES(c, slice_size_big <= std::numeric_limits<Index>::max(),
+              errors::InvalidArgument("slice size is too large for indexing: ",
+                                      slice_size_big, " > ",
+                                      std::numeric_limits<Index>::max()));
+
+  *slice_size = static_cast<Index>(slice_size_big);
+
+  // Guard the division against a zero-length index tuple.
+  const int64 safe_indices_nd = (*indices_nd < 1) ? 1 : *indices_nd;
+  *num_updates = indices_shape.num_elements() / safe_indices_nd;
+}
+
+// Kernel for the "ScatterNd" op: allocates a zero-filled output of the
+// requested `shape` and adds every slice of `updates` at the position named
+// by the matching index tuple in `indices`.
+//
+// Inputs: 0 = indices (Index), 1 = updates (T), 2 = shape (Index vector).
+// Output: 0 = the scattered tensor (T).
+template <typename Device, typename T, typename Index>
+class ScatterNdOp : public OpKernel {
+ public:
+  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    const DataType index_t = DataTypeToEnum<Index>::v();
+    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
+  }
+
+  void Compute(OpKernelContext* c) override {
+    const Tensor& indices = c->input(0);
+    const Tensor& updates = c->input(1);
+    const Tensor& shape_input = c->input(2);
+
+    OP_REQUIRES(c, shape_input.dims() == 1,
+                errors::InvalidArgument("Shape must be a vector"));
+    auto vec = shape_input.flat<Index>();
+    TensorShape shape;
+    // Propagate the failure instead of continuing with a bogus shape.
+    OP_REQUIRES_OK(c,
+                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
+
+    int64 indices_nd;
+    Index num_updates;
+    Index slice_size;
+    PrepareAndValidateInputs<Index>(c, shape, indices, updates, &indices_nd,
+                                    &num_updates, &slice_size);
+    if (!c->status().ok()) return;
+
+    // The scratch scalar is viewed as `Index` below, so it has to be
+    // allocated with the matching dtype (DT_INT32 is wrong for int64).
+    Tensor scratch;
+    OP_REQUIRES_OK(c, c->allocate_temp(DataTypeToEnum<Index>::v(),
+                                       TensorShape(), &scratch));
+
+    auto scratch_scalar = scratch.scalar<Index>();
+    auto indices_flat = indices.flat_inner_dims<Index>();
+    auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
+
+    // Stays -1 unless the functor reports the row of a bad index tuple.
+    Index bad_i = -1;
+    switch (indices_nd) {
+#define PARAMS_CASE(IXDIM)                                                    \
+  case IXDIM: {                                                               \
+    Tensor* out = nullptr;                                                    \
+    OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));                    \
+    functor::SetZeroFunctor<Device, T> fill;                                  \
+    fill(c->eigen_device<Device>(), out->flat<T>());                          \
+    if (shape.num_elements() > 0) {                                           \
+      auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>();              \
+      functor::ScatterNdFunctor<Device, T, Index,                             \
+                                scatter_nd_op::UpdateOp::ADD, (IXDIM)>        \
+          functor;                                                            \
+      bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar,  \
+                      output_flat, indices_flat, updates_flat, output_flat);  \
+    }                                                                         \
+  } break
+      PARAMS_CASE(0);
+      PARAMS_CASE(1);
+      PARAMS_CASE(2);
+      PARAMS_CASE(3);
+      PARAMS_CASE(4);
+      PARAMS_CASE(5);
+#undef PARAMS_CASE
+      default:
+        OP_REQUIRES(c, false,
+                    errors::InvalidArgument(
+                        "Only indices.shape[-1] values between 0 and 5 "
+                        "are currently supported.  Requested rank: ",
+                        indices_nd));
+    }
+    OP_REQUIRES(
+        c, bad_i < 0,
+        errors::InvalidArgument(
+            "Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
+            " = [", str_util::Join(gtl::ArraySlice<Index>(
+                                       &indices_flat(bad_i, 0), indices_nd),
+                                   ", "),
+            "] does not index into ", shape.DebugString()));
+  }
+};
+
+// Kernel shared by the in-place ScatterNd* ops; this file registers it for
+// ScatterNdUpdate, ScatterNdAdd, ScatterNdSub, ScatterNdMul and ScatterNdDiv.
+// `op` selects how every slice of `updates` is combined with the slice of the
+// ref input named by the matching index tuple.
+//
+// Inputs: 0 = params (ref T), 1 = indices (Index), 2 = updates (T).
+// Output: 0 = the ref input, forwarded.
+// Attr `use_locking`: when true the ref's mutex is held while updating.
+template <typename Device, typename T, typename Index,
+          scatter_nd_op::UpdateOp op>
+class ScatterNdUpdateOp : public OpKernel {
+ public:
+  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    const DataType dt_ref = DataTypeToEnum<T>::ref();
+    const DataType index_t = DataTypeToEnum<Index>::v();
+    OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
+    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* c) override {
+    if (use_exclusive_lock_) {
+      // Hold mutex while we apply updates
+      mutex_lock l(*c->input_ref_mutex(0));
+      DoCompute(c);
+    } else {
+      DoCompute(c);
+    }
+  }
+
+ private:
+  bool use_exclusive_lock_;
+
+  // Validates the inputs and applies `op` slice by slice to the ref input.
+  void DoCompute(OpKernelContext* c) {
+    Tensor params = c->mutable_input(0, use_exclusive_lock_);
+    const Tensor& indices = c->input(1);
+    const Tensor& updates = c->input(2);
+    const TensorShape& params_shape(params.shape());
+
+    int64 indices_nd;
+    Index num_updates;
+    Index slice_size;
+
+    OP_REQUIRES(c, params.IsInitialized(),
+                errors::FailedPrecondition("Null ref for params"));
+    PrepareAndValidateInputs<Index>(c, params_shape, indices, updates,
+                                    &indices_nd, &num_updates, &slice_size);
+    if (!c->status().ok()) return;
+
+    // The scratch scalar is viewed as `Index` below, so it has to be
+    // allocated with the matching dtype (DT_INT32 is wrong for int64).
+    Tensor scratch;
+    OP_REQUIRES_OK(c, c->allocate_temp(DataTypeToEnum<Index>::v(),
+                                       TensorShape(), &scratch));
+
+    auto scratch_scalar = scratch.scalar<Index>();
+    auto indices_flat = indices.flat_inner_dims<Index>();
+    auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
+
+    // Stays -1 unless the functor reports the row of a bad index tuple.
+    Index bad_i = -1;
+    c->forward_ref_input_to_ref_output(0, 0);
+    switch (indices_nd) {
+#define PARAMS_CASE(IXDIM)                                                  \
+  case IXDIM: {                                                             \
+    auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>();            \
+    functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor;         \
+    bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar,  \
+                    params_flat, indices_flat, updates_flat, params_flat);  \
+  } break
+      PARAMS_CASE(0);
+      PARAMS_CASE(1);
+      PARAMS_CASE(2);
+      PARAMS_CASE(3);
+      PARAMS_CASE(4);
+      PARAMS_CASE(5);
+#undef PARAMS_CASE
+      default:
+        // The switch above handles 0 through 5, so report that range.
+        OP_REQUIRES(c, false,
+                    errors::InvalidArgument(
+                        "Only indices.shape[-1] values between 0 and 5 "
+                        "are currently supported.  Requested rank: ",
+                        indices_nd));
+    }
+    OP_REQUIRES(
+        c, bad_i < 0,
+        errors::InvalidArgument(
+            "Invalid indices: ", SliceDebugString(indices.shape(), bad_i),
+            " = [", str_util::Join(gtl::ArraySlice<Index>(
+                                       &indices_flat(bad_i, 0), indices_nd),
+                                   ", "),
+            "] is not in [0, ", params.dim_size(0), ")"));
+  }
+};
+
+// Specialization of ScatterNdSliceGenerator to CPU
+namespace generator {
+
+// Applies one update operation to a contiguous slice of `slice_size`
+// elements: output[i] = input[i] <op> updates[i]. The primary template is
+// only declared; each supported UpdateOp has its own specialization below.
+template <typename T, typename Index, scatter_nd_op::UpdateOp op>
+class UpdateExecutor {
+ public:
+  static void Update(T* input, const T* updates, T* output, Index slice_size);
+};
+
+// ASSIGN: the previous contents are ignored; `updates` is copied verbatim.
+template <typename T, typename Index>
+class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> {
+ public:
+  static void Update(T* /* unused */, const T* updates, T* output,
+                     Index slice_size) {
+    std::copy(updates, updates + slice_size, output);
+  }
+};
+
+// ADD: output[i] = input[i] + updates[i].
+template <typename T, typename Index>
+class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> {
+ public:
+  static void Update(T* input, const T* updates, T* output, Index slice_size) {
+    const std::plus<T> add;
+    for (Index i = 0; i < slice_size; ++i) {
+      output[i] = add(input[i], updates[i]);
+    }
+  }
+};
+
+// SUB: output[i] = input[i] - updates[i].
+template <typename T, typename Index>
+class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> {
+ public:
+  static void Update(T* input, const T* updates, T* output, Index slice_size) {
+    const std::minus<T> subtract;
+    for (Index i = 0; i < slice_size; ++i) {
+      output[i] = subtract(input[i], updates[i]);
+    }
+  }
+};
+
+// MUL: output[i] = input[i] * updates[i].
+template <typename T, typename Index>
+class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> {
+ public:
+  static void Update(T* input, const T* updates, T* output, Index slice_size) {
+    const std::multiplies<T> multiply;
+    for (Index i = 0; i < slice_size; ++i) {
+      output[i] = multiply(input[i], updates[i]);
+    }
+  }
+};
+
+// DIV: output[i] = input[i] / updates[i].
+template <typename T, typename Index>
+class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> {
+ public:
+  static void Update(T* input, const T* updates, T* output, Index slice_size) {
+    const std::divides<T> divide;
+    for (Index i = 0; i < slice_size; ++i) {
+      output[i] = divide(input[i], updates[i]);
+    }
+  }
+};
+
+// Eigen generator that applies one update per call. operator() receives the
+// row number of an index tuple in Tindices, bounds-checks that tuple against
+// the leading IXDIM dimensions of Tparams and, when it is valid, runs
+// UpdateExecutor on the addressed slice. The row of an out-of-bounds tuple is
+// stored in *error_loc. The returned value is a dummy.
+template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
+class ScatterNdSliceGenerator {
+ public:
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator(
+      const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams,
+      typename TTypes<Index, 2>::ConstTensor Tindices,
+      typename TTypes<T, 2>::ConstTensor Tupdates,
+      typename TTypes<T, IXDIM + 1>::Tensor Toutput,
+      std::atomic<Index>* error_loc)
+      : slice_size_(slice_size),
+        Tparams_(Tparams),
+        Tindices_(Tindices),
+        Tupdates_(Tupdates),
+        Toutput_(Toutput),
+        error_loc_(error_loc) {}
+
+  // Fills `ix` with the coordinates of the slice addressed by row `loc` of
+  // Tindices (the trailing coordinate is 0, i.e. the start of the slice).
+  // Returns true when any component falls outside Tparams.
+  EIGEN_DEVICE_FUNC bool GenerateIndices(
+      const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
+    bool any_out_of_bounds = false;
+    for (int dim = 0; dim < IXDIM; ++dim) {
+      const Index component = internal::SubtleMustCopy(Tindices_(loc, dim));
+      (*ix)[dim] = component;
+      if (!FastBoundsCheck(component, Tparams_.dimension(dim))) {
+        any_out_of_bounds = true;
+      }
+    }
+    (*ix)[IXDIM] = 0;
+    return any_out_of_bounds;
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
+  operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
+    const auto row = loc_array[0];
+
+    Eigen::array<Eigen::DenseIndex, IXDIM + 1> params_coord;
+    if (TF_PREDICT_FALSE(GenerateIndices(row, &params_coord))) {
+      // Remember the offending row; the caller turns it into an error.
+      error_loc_->store(row);
+      return static_cast<int32>(0);
+    }
+
+    // Start of the slice of `updates` that belongs to this row.
+    Eigen::array<Eigen::DenseIndex, 2> updates_coord;
+    updates_coord[0] = row;
+    updates_coord[1] = 0;
+
+    UpdateExecutor<T, Index, op>::Update(&Tparams_(params_coord),
+                                         &Tupdates_(updates_coord),
+                                         &Toutput_(params_coord), slice_size_);
+    return static_cast<int32>(0);  // Return something...
+  }
+
+ protected:
+  const Index slice_size_;
+  mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_;
+  const typename TTypes<Index, 2>::ConstTensor Tindices_;
+  const typename TTypes<T, 2>::ConstTensor Tupdates_;
+  mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_;
+  std::atomic<Index>* error_loc_;
+};
+
+} // namespace generator
+
+namespace functor {
+// Implementation of update functor for CPU.
+//
+// The updates are performed as a side effect of evaluating an Eigen
+// expression: the scratch scalar is reshaped to one element, broadcast to one
+// element per row of Tindices, and every element is produced by a
+// ScatterNdSliceGenerator call that applies the update for that row. The sum
+// written back into Tscratch is only a sink for the expression.
+//
+// Returns -1 when no out-of-bounds index was seen, otherwise the row in
+// Tindices that the generator stored for an out-of-bounds index tuple.
+template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
+struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
+ Index operator()(const CPUDevice& d, const Index slice_size,
+ typename TTypes<Index>::Scalar Tscratch,
+ typename TTypes<T, IXDIM + 1>::Tensor Tparams,
+ typename TTypes<Index, 2>::ConstTensor Tindices,
+ typename TTypes<T, 2>::ConstTensor Tupdates,
+ typename TTypes<T, IXDIM + 1>::Tensor Toutput) {
+ // Written by the generator when it meets an out-of-bounds index tuple.
+ std::atomic<Index> error_loc(-1);
+
+ // One generator call per row of Tindices.
+ const Eigen::DenseIndex batch_size = Tindices.dimension(0);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
+ Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
+#else
+ Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
+ Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
+ broadcast_dims.set(0, batch_size);
+#endif
+
+ generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator(
+ slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc);
+ Tscratch.device(d) = Tscratch.reshape(reshape_dims)
+ .broadcast(broadcast_dims)
+ .generate(generator)
+ .sum();
+
+ // error_loc holds -1 if there's no out-of-bounds index,
+ // otherwise it holds the location of an OOB index in Tindices.
+ return error_loc.load();
+ }
+};
+} // namespace functor
+
+// Kernel registration helpers. The *_INDEX macros register one kernel for a
+// single (value type, index type, device) combination; the macros without the
+// suffix expand to the int32 and int64 index variants.
+
+// Registers ScatterNdOp under op `name` for one value/index type pair.
+#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
+ REGISTER_KERNEL_BUILDER(Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterNdOp<dev##Device, type, index_type>)
+
+// Registers ScatterNdUpdateOp with update operation `op` under op `name`.
+#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
+ op) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterNdUpdateOp<dev##Device, type, index_type, op>)
+
+#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
+ REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
+ REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
+
+#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
+
+// Despite the name, this registers all four arithmetic variants:
+// ScatterNdAdd, ScatterNdSub, ScatterNdMul and ScatterNdDiv.
+#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
+ scatter_nd_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
+ scatter_nd_op::UpdateOp::SUB); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
+ scatter_nd_op::UpdateOp::MUL); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
+ scatter_nd_op::UpdateOp::DIV);
+
+#define REGISTER_SCATTER_ND(type, dev) \
+ REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
+
+#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
+ scatter_nd_op::UpdateOp::ASSIGN);
+
+// Registers CPU kernels.
+#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
+ REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
+
+#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
+ REGISTER_SCATTER_ND_UPDATE(type, CPU);
+
+#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
+
+// The arithmetic variants are registered for the number types only; the
+// ASSIGN update and ScatterNd itself are registered for all types.
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
+
+// Registers GPU kernels.
+// The helper macros are defined, but no GPU kernel is registered yet: the
+// TF_CALL_* invocations below are commented out.
+#if GOOGLE_CUDA
+#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
+ REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
+
+#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
+ REGISTER_SCATTER_ND_UPDATE(type, GPU);
+
+// TODO(simister): Re-enable when GPU support is working.
+// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU);
+// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU);
+
+#endif // GOOGLE_CUDA
+
+// Undefine every registration helper defined in this file so that none of
+// them leaks into the remainder of the translation unit.
+#undef REGISTER_SCATTER_ND
+#undef REGISTER_SCATTER_ND_CPU
+#undef REGISTER_SCATTER_ND_ADD_SUB
+#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
+#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
+#undef REGISTER_SCATTER_ND_UPDATE
+#undef REGISTER_SCATTER_ND_UPDATE_CPU
+#undef REGISTER_SCATTER_ND_UPDATE_GPU
+#undef REGISTER_SCATTER_ND_KERNEL
+#undef REGISTER_SCATTER_ND_KERNEL_INDEX
+#undef REGISTER_SCATTER_ND_UPDATE_KERNEL
+#undef REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// Declares the GPU specialization for one (T, Index, op, NDIM) combination.
+// The parameter list mirrors the call made by ScatterNdOp/ScatterNdUpdateOp
+// in this file (device, slice size, scratch scalar, params, indices, updates,
+// output). The rank argument is NDIM throughout; the macro has no IXDIM.
+// NOTE(review): the declaration of ScatterNdFunctor::operator() in
+// scatter_nd_op.h has to agree with this signature before the
+// TF_CALL_GPU_NUMBER_TYPES_NO_HALF line below is re-enabled.
+#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM)                      \
+  template <>                                                         \
+  Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()(  \
+      const GPUDevice& d, const Index slice_size,                     \
+      typename TTypes<Index>::Scalar Tscratch,                        \
+      typename TTypes<T, (NDIM) + 1>::Tensor Tparams,                 \
+      typename TTypes<Index, 2>::ConstTensor Tindices,                \
+      typename TTypes<T, 2>::ConstTensor Tupdates,                    \
+      typename TTypes<T, (NDIM) + 1>::Tensor Toutput);                \
+  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>;
+
+// Index-tuple lengths 0 through 5, matching the kernels' switch statements.
+#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 0);    \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 1);    \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 2);    \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 3);    \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 4);    \
+  DECLARE_GPU_SPECS_OP(T, Index, op, 5)
+
+#define DECLARE_GPU_SPECS_INDEX(T, Index)                           \
+  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
+  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD);    \
+  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB);    \
+  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL);    \
+  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);
+
+#define DECLARE_GPU_SPECS(T)         \
+  DECLARE_GPU_SPECS_INDEX(T, int32); \
+  DECLARE_GPU_SPECS_INDEX(T, int64);
+
+// TODO(simister): Re-enable when GPU support is working.
+// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_INDEX
+#undef DECLARE_GPU_SPECS_OPS
+#undef DECLARE_GPU_SPECS_OP
+
+}  // namespace functor
+#endif  // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h
new file mode 100644
index 0000000000..e4c8e7ed9f
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
+#define TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
+
+// Functor definitions for ScatterND ops, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace scatter_nd_op {
+
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
+
+} // namespace scatter_nd_op
+
+namespace functor {
+
+// Functor used by the ScatterNd ops to do the computations.
+//
+// The declaration matches the way the kernels in scatter_nd_op.cc invoke the
+// functor, and the CPU specialization defined there:
+//   d          - device to run on.
+//   slice_size - number of elements in the slice addressed by one index tuple.
+//   Tscratch   - scalar of type Index used as scratch space.
+//   Tparams    - tensor the slices are read from, flattened to IXDIM + 1 dims.
+//   Tindices   - 2-D tensor holding one index tuple per row.
+//   Tupdates   - 2-D tensor holding one update slice per row.
+//   Toutput    - tensor the results are written to, same rank as Tparams.
+template <typename Device, typename T, typename Index,
+          scatter_nd_op::UpdateOp op, int IXDIM>
+struct ScatterNdFunctor {
+  // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
+  Index operator()(const Device& d, const Index slice_size,
+                   typename TTypes<Index>::Scalar Tscratch,
+                   typename TTypes<T, IXDIM + 1>::Tensor Tparams,
+                   typename TTypes<Index, 2>::ConstTensor Tindices,
+                   typename TTypes<T, 2>::ConstTensor Tupdates,
+                   typename TTypes<T, IXDIM + 1>::Tensor Toutput);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SCATTER_ND_OP_H_
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
new file mode 100644
index 0000000000..d6743a6867
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -0,0 +1,320 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+// Test fixture that builds a single "ScatterNdUpdate" node.
+class ScatterNdUpdateOpTest : public OpsTestBase {
+ protected:
+ // Creates and initializes the op. `variable_ref_type` is the ref dtype of
+ // the first input; the third input uses the same dtype with the ref
+ // qualifier removed. `index_type` is the dtype of the second input.
+ void MakeOp(DataType variable_ref_type, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
+ .Input(FakeInput(variable_ref_type))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(RemoveRefType(variable_ref_type)))
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+// A single string element is overwritten by the update.
+TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
+ MakeOp(DT_STRING_REF, DT_INT32);
+ AddInputFromArray<string>(TensorShape({1}), {"Brain"});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
+ TF_ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_STRING, TensorShape({1}));
+ test::FillValues<string>(&expected, {"TensorFlow"});
+ test::ExpectTensorEqual<string>(expected, params_tensor);
+}
+
+// A single bool element is overwritten by the update.
+TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
+ MakeOp(DT_BOOL_REF, DT_INT32);
+ AddInputFromArray<bool>(TensorShape({1}), {false});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<bool>(TensorShape({1}), {true});
+ TF_ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
+ test::FillValues<bool>(&expected, {true});
+ test::ExpectTensorEqual<bool>(expected, params_tensor);
+}
+
+// Three whole rows of a 5x3 float tensor are replaced using int32 indices of
+// shape [3, 1] (rows 0, 4 and 2).
+TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+// Same scenario as Simple_TwoD32, but with int64 indices.
+TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
+ MakeOp(DT_FLOAT_REF, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({0}), {});
+ AddInputFromArray<int32>(TensorShape({0}), {});
+ AddInputFromArray<float>(TensorShape({0}), {});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Output must not have 0 elements, got shape: "))
+ << s;
+}*/
+
+// A rank-1 indices tensor holding one index updates a single element of a
+// 5-element vector.
+TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1}), {3});
+ AddInputFromArray<float>(TensorShape({1}), {101});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+// Three scalar updates are scattered into a 5-element vector using indices of
+// shape [3, 1].
+TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+// Indices of shape [2, 3, 1] (two batch dimensions) paired with updates of
+// shape [2, 3] scatter six scalars into an 8-element vector.
+TEST_F(ScatterNdUpdateOpTest, HigherRank) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
+ AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
+ test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+// Index 99 does not fit the first dimension (5) of params; the op must report
+// the offending index tuple and its position in the indices tensor.
+TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Invalid indices: [2,0] = [99] is not in [0, 5)"))
+ << s;
+}
+
+// The outermost dimension of indices (1) differs from that of updates (3);
+// the op must reject the inputs and print both shapes.
+TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("The outermost dimension of updates and indices "
+ "must match. Got indices.shape [1,3,1], "
+ "updates.shape [3,3]"))
+ << s;
+}
+
+// The update slices have 4 columns while the params rows have 3; the op must
+// fail the update-shape validation.
+TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(
+ TensorShape({3, 4}),
+ {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Must have updates.shape = indices.shape[0] + "
+ "params_shape[IXDIM:], got"))
+
+ << s;
+}
+
+// Three index tuples are paired with only two update rows; the op must
+// report that the outermost dimensions do not match.
+TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({2, 3}),
+ {100, 101, 102, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("The outermost dimension of updates and indices "
+ "must match. Got "))
+ << s;
+}
+
+// Benchmark harness that reuses the test fixture outside of a TEST_F.
+// TestBody() is given an empty definition so the class can be instantiated
+// directly by the benchmark helper.
+class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
+ public:
+ virtual void TestBody() {}
+ // Builds and initializes the op named `op` with a float ref input, indices
+ // of `index_type` and float updates.
+ void MakeBenchmarkOp(const char* op, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", op)
+ .Input(FakeInput(DT_FLOAT_REF))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_CHECK_OK(InitOp());
+ }
+};
+
+// Shared body of the ScatterNd benchmarks. Builds a [kRows, embedding_size]
+// float tensor of roughly ten million elements, draws kNumUpdates random row
+// indices and times `iters` executions of the op named `op`.
+// Only the RunOpKernel() loop is timed; input construction is excluded.
+template <typename Index>
+static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
+  testing::StopTiming();
+  const int kRows = 10000000 / embedding_size;
+  const int kNumValues = kRows * embedding_size;
+  std::vector<float> values;
+  // Reserve room for every element that is pushed below, not just one per
+  // row.
+  values.reserve(kNumValues);
+  for (int i = 0; i < kNumValues; i++) {
+    values.push_back(i);
+  }
+  const int kNumUpdates = 1000;
+  random::PhiloxRandom philox(301, 17);
+  random::SimplePhilox rnd(&philox);
+  std::vector<Index> indices;
+  std::vector<float> updates;
+  indices.reserve(kNumUpdates);
+  updates.reserve(static_cast<size_t>(kNumUpdates) * embedding_size);
+  for (int i = 0; i < kNumUpdates; i++) {
+    indices.push_back(rnd.Uniform(kRows));
+    for (int j = 0; j < embedding_size; j++) {
+      updates.push_back(i * 10 + j);
+    }
+  }
+
+  ScatterNdUpdateBM bm;
+  bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
+  bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
+  bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
+  bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
+                              updates);
+  testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
+                          iters);
+  testing::StartTiming();
+  while (iters-- > 0) {
+    // The status is not inspected here, as in the original benchmark.
+    Status s = bm.RunOpKernel();
+  }
+  testing::StopTiming();
+}
+
+// Benchmark entry points: one per (op, index type) pair. The benchmark
+// argument is forwarded as the embedding size, i.e. the number of floats in
+// each updated row.
+static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
+}
+static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
+}
+
+static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
+}
+static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
+}
+
+// Every benchmark is run with embedding sizes 1, 10, 64, 256 and 1024.
+BENCHMARK(BM_ScatterNdUpdateInt32)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(64)
+ ->Arg(256)
+ ->Arg(1024);
+BENCHMARK(BM_ScatterNdUpdateInt64)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(64)
+ ->Arg(256)
+ ->Arg(1024);
+
+BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 6e076a092e..aff547f78a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -4382,6 +4382,83 @@ output_min: This value is copied from input_min.
output_max: This value is copied from input_max.
)Doc");
+REGISTER_OP("ScatterNd")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Input("shape: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Doc(
+ R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices.
+This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.
+
+TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.
+
+`shape` is a 1-D integer `Tensor` with `P` elements that gives the shape of the rank-`P` output, and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into a tensor of shape `shape`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `shape`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].
+```
+
+The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/ScatterNd1.png" alt>
+</div>
+
+In Python, this scatter operation would look like this:
+
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ shape = tf.constant([8])
+ scatter = tf.scatter_nd(indices, updates, shape)
+ with tf.Session() as sess:
+ print sess.run(scatter)
+
+The resulting tensor would look like this:
+
+ [0, 11, 0, 10, 9, 0, 0, 12]
+
+We can also insert entire slices of a higher rank tensor all at once. For example, say we want to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/ScatterNd2.png" alt>
+</div>
+
+In Python, this scatter operation would look like this:
+
+ indices = tf.constant([[0], [2]])
+ updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
+ [7, 7, 7, 7], [8, 8, 8, 8]],
+ [[5, 5, 5, 5], [6, 6, 6, 6],
+ [7, 7, 7, 7], [8, 8, 8, 8]]])
+ shape = tf.constant([4, 4, 4])
+ scatter = tf.scatter_nd(indices, updates, shape)
+ with tf.Session() as sess:
+ print sess.run(scatter)
+
+The resulting tensor would look like this:
+
+ [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
+ [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
+ [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
+ [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
+
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into the output tensor.
+updates: A Tensor. Must have the same type as output. A tensor of updated values to scatter into the output tensor.
+shape: A vector. The shape of the resulting tensor.
+output: A new tensor with the given shape and updates applied according to the indices.)doc");
+
REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index b9ac8b16ff..9339b9b821 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -445,6 +445,241 @@ use_locking: If True, the operation will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
+REGISTER_OP("ScatterNdUpdate")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .Doc(
+ R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements in a rank-1 tensor with 8 elements. In Python, that update would look like this:
+
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(update)
+
+The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
+
+ref: A mutable Tensor. Should be from a Variable node.
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated values to store in ref.
+use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
+
+REGISTER_OP("ScatterNdAdd")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(
+ R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to add 4 scattered elements to a rank-1 tensor with 8 elements. In Python, that addition would look like this:
+
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = tf.scatter_nd_add(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(add)
+
+The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
+
+ref: A mutable Tensor. Should be from a Variable node.
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
+use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
+
+REGISTER_OP("ScatterNdSub")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(
+ R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:
+
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ sub = tf.scatter_nd_sub(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(sub)
+
+The resulting update to ref would look like this:
+
+ [1, -9, 3, -6, -4, 6, 7, -4]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
+
+ref: A mutable Tensor. Should be from a Variable node.
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
+use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
+
+REGISTER_OP("ScatterNdMul")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(
+ R"doc(Applies sparse multiplication between `updates` and individual values or slices within a given variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:
+
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ mul = tf.scatter_nd_mul(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(mul)
+
+The resulting update to ref would look like this:
+
+ [1, 22, 3, 40, 45, 6, 7, 96]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
+
+ref: A mutable Tensor. Should be from a Variable node.
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated values to multiply ref by.
+use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
+
+REGISTER_OP("ScatterNdDiv")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(
+ R"doc(Applies sparse division between `updates` and individual values or slices within a given variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:
+
+ ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([2, 3, 4, 5])
+ div = tf.scatter_nd_div(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(div)
+
+The resulting update to ref would look like this:
+
+ [10, 5, 30, 13, 25, 60, 70, 16]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
+
+ref: A mutable Tensor. Should be from a Variable node.
+indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated values to divide ref by.
+use_locking: An optional bool. Defaults to False. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
+output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
+
REGISTER_OP("CountUpTo")
.Input("ref: Ref(T)")
.Output("output: T")
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index baa48ec8a4..67a086d4fc 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -263,6 +263,13 @@ tf_py_test(
)
tf_py_test(
+ name = "scatter_nd_ops_test",
+ size = "medium",
+ srcs = ["scatter_nd_ops_test.py"],
+ additional_deps = ["//tensorflow:tensorflow_py"],
+)
+
+tf_py_test(
name = "segment_reduction_ops_test",
size = "small",
srcs = ["segment_reduction_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index f3fd47381a..13b3bec3c0 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -53,20 +53,20 @@ class GatherNdTest(tf.test.TestCase):
gather_nd_ok_t = tf.gather_nd(params, indices_empty)
gather_nd_ok_val = gather_nd_ok_t.eval()
self.assertEqual([0], gather_nd_ok_t.get_shape())
- self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
+ self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
indices_empty = np.empty((0, 1), dtype=np.int32)
gather_nd_ok_t = tf.gather_nd(params, indices_empty)
gather_nd_ok_val = gather_nd_ok_t.eval()
self.assertEqual([0, 3], gather_nd_ok_t.get_shape())
- self.assertAllEqual(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val)
+ self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val)
params_empty = np.empty((0, 3), dtype=np.float32)
indices_empty = np.empty((0, 2), dtype=np.int32)
gather_nd_ok_t = tf.gather_nd(params_empty, indices_empty)
gather_nd_ok_val = gather_nd_ok_t.eval()
self.assertEqual([0], gather_nd_ok_t.get_shape())
- self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
+ self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
params_empty = np.empty((0, 3), dtype=np.float32)
indices_nonempty = np.zeros((1, 2), dtype=np.int32)
@@ -74,7 +74,7 @@ class GatherNdTest(tf.test.TestCase):
with self.assertRaisesOpError(
r"Requested more than 0 entries, but params is empty."):
gather_nd_break_t.eval()
- self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
+ self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
def testIndexScalar(self):
with self.test_session(use_gpu=self.use_gpu):
@@ -184,11 +184,11 @@ class GatherNdTest(tf.test.TestCase):
indices = tf.placeholder(tf.int32)
gather_nd_t = tf.gather_nd(params, indices)
shape = gather_nd_t.get_shape()
- self.assertEqual(shape.ndims, None)
- self.assertEqual(shape[0].value, None)
+ self.assertEqual(None, shape.ndims)
+ self.assertEqual(None, shape[0].value)
def testBadIndices(self):
- with self.test_session(use_gpu=False):
+ with self.test_session():
params = [0, 1, 2]
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = tf.gather_nd(params, indices)
@@ -198,7 +198,7 @@ class GatherNdTest(tf.test.TestCase):
gather_nd.eval()
def testBadIndicesWithSlices(self):
- with self.test_session(use_gpu=False):
+ with self.test_session():
params = [[0, 1, 2]]
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = tf.gather_nd(params, indices)
@@ -207,6 +207,62 @@ class GatherNdTest(tf.test.TestCase):
r"\(shape: \[1,3\]\)"):
gather_nd.eval()
+ def testGradientsRank2Elements(self):
+ indices = tf.constant([[0, 0], [1, 1]], dtype=tf.int32)
+ inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float64)
+ outputs = tf.gather_nd(inputs, indices)
+
+ grad_vals = tf.constant([1, 2], dtype=tf.float64)
+ grads = tf.gradients([outputs], [inputs], [grad_vals])[0]
+ expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64)
+ with self.test_session():
+ assert np.array_equal(expected_grads, grads.eval())
+
+ def testGradientsRank2Slices(self):
+ indices = tf.constant([[1], [0]], dtype=tf.int32)
+ inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float64)
+ outputs = tf.gather_nd(inputs, indices)
+
+ grad_vals = tf.constant([[1, 2], [3, 4]], dtype=tf.float64)
+ grads = tf.gradients([outputs], [inputs], [grad_vals])[0]
+ expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
+ def testGradientsRank3Elements(self):
+ indices = tf.constant([[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=tf.int32)
+ inputs = tf.constant([[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=tf.float64)
+ outputs = tf.gather_nd(inputs, indices)
+
+ grad_vals = tf.constant(
+ [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=tf.float64)
+ grads = tf.gradients([outputs], [inputs], [grad_vals])[0]
+ expected_grads = np.array(
+ [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
+ def testGradientsRank2SlicesWithEmptySpace(self):
+ indices = tf.constant([[2], [0], [5]], dtype=tf.int32)
+ inputs = tf.constant(
+ [[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9]],
+ dtype=tf.float64)
+ outputs = tf.gather_nd(inputs, indices)
+ grad_vals = tf.constant(
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2],
+ [3, 3, 3, 3, 3, 3, 3, 3, 3]],
+ dtype=tf.float64)
+ grads = tf.gradients([outputs], [inputs], [grad_vals])[0]
+ expected_grads = np.array(
+ [[2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
+ dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
class GatherNdGpuTest(GatherNdTest):
use_gpu = True
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
new file mode 100644
index 0000000000..2ad7ae810c
--- /dev/null
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -0,0 +1,383 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.scatter_nd."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from operator import add
+from operator import mul
+from operator import sub
+
+import numpy as np
+import tensorflow as tf
+
+
+def _AsType(v, vtype):
+ return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
+
+
+def _FlatInnerDims(tensor, ndims=2):
+ shape = list(tensor.shape)
+ return tensor.reshape([functools.reduce(mul, shape[:-ndims + 1], 1)] + shape[
+ -ndims + 1:])
+
+
+def _FlatOuterDims(tensor, ndims=2):
+ shape = list(tensor.shape)
+ return tensor.reshape(shape[:ndims - 1] +
+ [functools.reduce(mul, shape[ndims - 1:], 1)])
+
+
+def _NumpyScatterNd(ref, indices, updates, op):
+ ixdim = indices.shape[-1]
+ num_updates = indices.size / ixdim
+ total_nd = len(ref.shape)
+ slice_size = 1
+ for i in range(ixdim, total_nd):
+ slice_size *= ref.shape[i]
+ flat_indices = _FlatInnerDims(indices)
+ flat_updates = updates.reshape((num_updates, slice_size))
+ output_flat = _FlatOuterDims(ref, ixdim + 1)
+ for ix_updates, ix_output in enumerate(flat_indices):
+ ix_output = tuple(ix_output)
+ output_flat[ix_output] = op(output_flat[ix_output],
+ flat_updates[ix_updates])
+ return output_flat.reshape(ref.shape)
+
+
+def _NumpyUpdate(ref, indices, updates):
+ return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
+
+
+def _NumpyAdd(ref, indices, updates):
+ return _NumpyScatterNd(ref, indices, updates, add)
+
+
+def _NumpySub(ref, indices, updates):
+ return _NumpyScatterNd(ref, indices, updates, sub)
+
+
+def _NumpyMul(ref, indices, updates):
+ return _NumpyScatterNd(ref, indices, updates, mul)
+
+
+def _NumpyDiv(ref, indices, updates):
+ return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
+
+
+class ScatterTest(tf.test.TestCase):
+
+ def _VariableRankTest(self,
+ np_scatter,
+ tf_scatter,
+ vtype,
+ itype,
+ use_gpu,
+ repeat_indices=False):
+ np.random.seed(8)
+ ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
+ indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
+ with self.test_session(use_gpu=use_gpu):
+ for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
+ num_updates = indices_shape[0]
+ ixdim = indices_shape[-1]
+
+ indexable_area_shape = ()
+ for i in range(ixdim):
+ indexable_area_shape += (ref_shape[i],)
+ all_indices = [
+ list(coord)
+ for coord, _ in np.ndenumerate(
+ np.empty(indexable_area_shape, vtype))
+ ]
+ np.random.shuffle(all_indices)
+ indices = np.array(all_indices[:num_updates])
+
+ if num_updates > 1 and repeat_indices:
+ indices = indices[:num_updates // 2]
+ for _ in range(num_updates - num_updates // 2):
+ indices = np.append(
+ indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
+ np.random.shuffle(indices)
+ indices = _AsType(indices[:num_updates], itype)
+
+ updates_shape = (num_updates,)
+ for i in range(ixdim, len(ref_shape)):
+ updates_shape += (ref_shape[i],)
+ updates = _AsType(np.random.randn(*(updates_shape)), vtype)
+ ref = _AsType(np.random.randn(*(ref_shape)), vtype)
+
+ # Scatter via numpy
+ new = ref.copy()
+ np_scatter(new, indices, updates)
+ # Scatter via tensorflow
+ ref_var = tf.Variable(ref)
+ ref_var.initializer.run()
+ tf_scatter(ref_var, indices, updates).eval()
+ # Compare
+ self.assertAllClose(new, ref_var.eval())
+
+ def _VariableRankTests(self, np_scatter, tf_scatter):
+ for vtype in (np.float32, np.float64):
+ for itype in (np.int32, np.int64):
+ for use_gpu in (False, True):
+ self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu)
+
+ def testVariableRankUpdate(self):
+ self._VariableRankTests(_NumpyUpdate, tf.scatter_nd_update)
+
+ def testVariableRankAdd(self):
+ self._VariableRankTests(_NumpyAdd, tf.scatter_nd_add)
+
+ def testVariableRankSub(self):
+ self._VariableRankTests(_NumpySub, tf.scatter_nd_sub)
+
+ def testVariableRankMul(self):
+ self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul)
+
+ def testVariableRankDiv(self):
+ self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div)
+
+ def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
+ for vtype in (np.float32, np.float64):
+ for itype in (np.int32, np.int64):
+ for use_gpu in (False, True):
+ self._VariableRankTest(
+ np_scatter,
+ tf_scatter,
+ vtype,
+ itype,
+ use_gpu,
+ repeat_indices=True)
+
+ def testScatterRepeatIndices(self):
+ """Tests scatter_nd_add/sub/mul/div using indices that repeat."""
+ self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_nd_add)
+ self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_nd_sub)
+ self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul)
+ self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div)
+
+ def testBooleanScatterUpdate(self):
+ with self.test_session(use_gpu=False) as session:
+ var = tf.Variable([True, False])
+ update0 = tf.scatter_nd_update(var, [[1]], [True])
+ update1 = tf.scatter_nd_update(
+ var, tf.constant(
+ [[0]], dtype=tf.int64), [False])
+ var.initializer.run()
+
+ session.run([update0, update1])
+
+ self.assertAllEqual([False, True], var.eval())
+
+ def testScatterOutOfRangeCpu(self):
+ for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
+ tf.scatter_nd_div, tf.scatter_nd_update):
+ params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
+ updates = np.array([-3, -4, -5]).astype(np.float32)
+ with self.test_session(use_gpu=False):
+ ref = tf.Variable(params)
+ ref.initializer.run()
+
+ # Indices all in range, no problem.
+ indices = np.array([[2], [0], [5]])
+ op(ref, indices, updates).eval()
+
+ # Test some out of range errors.
+ indices = np.array([[-1], [0], [5]])
+ with self.assertRaisesOpError(
+ r"Invalid indices: \[0,0\] = \[-1\] is not in \[0, 6\)"):
+ op(ref, indices, updates).eval()
+
+ indices = np.array([[2], [0], [6]])
+ with self.assertRaisesOpError(
+ r"Invalid indices: \[2,0\] = \[6\] is not in \[0, 6\)"):
+ op(ref, indices, updates).eval()
+
+ def testRank3ValidShape(self):
+ indices = tf.zeros([2, 2, 2], tf.int32)
+ updates = tf.zeros([2, 2, 2], tf.int32)
+ shape = np.array([2, 2, 2])
+ self.assertAllEqual(
+ tf.scatter_nd(indices, updates, shape).get_shape().as_list(), shape)
+
+ ref = tf.Variable(tf.zeros(shape, tf.int32))
+ self.assertAllEqual(
+ tf.scatter_nd_update(ref, indices, updates).get_shape().as_list(),
+ shape)
+
+ def testUndefinedIndicesShape(self):
+ indices = tf.placeholder(tf.int32, shape=None)
+ updates = tf.placeholder(tf.int32, shape=[2, 2, 2])
+ shape = tf.constant([2, 2, 2], tf.int32)
+ tf.scatter_nd(indices, updates, shape)
+
+ def testUndefinedUpdatesShape(self):
+ indices = tf.placeholder(tf.int32, shape=[2, 2, 2])
+ updates = tf.placeholder(tf.int32, shape=None)
+ shape = tf.constant([2, 2, 2], tf.int32)
+ tf.scatter_nd(indices, updates, shape)
+
+ def testUndefinedOutputShape(self):
+ indices = tf.placeholder(tf.int32, shape=[2, 2, 2])
+ updates = tf.placeholder(tf.int32, shape=[2, 2, 2])
+ shape = tf.placeholder(tf.int32, shape=[None])
+ tf.scatter_nd(indices, updates, shape)
+
+ def testEmptyoutputShape1(self):
+ indices = tf.zeros([2, 2, 2], tf.int32)
+ updates = tf.zeros([2, 2, 2], tf.int32)
+ shape = tf.constant([0, 3, 2], tf.int32)
+
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "Indices and updates specified for empty output shape"):
+ tf.scatter_nd(indices, updates, shape)
+
+ def testEmptyoutputShape2(self):
+ indices = tf.placeholder(tf.int32, shape=None)
+ updates = tf.placeholder(tf.int32, shape=None)
+ shape = tf.constant([0, 3, 2], tf.int32)
+
+ with self.test_session():
+ #with self.assertRaisesOpError(
+ # r"Indices and updates specified for empty output shape"):
+ tf.scatter_nd(indices, updates, shape).eval(feed_dict={
+ indices: np.zeros(
+ [2, 2, 2], dtype=np.int32),
+ updates: np.zeros(
+ [2, 2, 2], dtype=np.int32)
+ })
+
+ def testEmptyoutputShape3(self):
+ indices = tf.zeros([0], tf.int32)
+ updates = tf.zeros([0], tf.int32)
+ shape = tf.constant([0], tf.int32)
+ scatter = tf.scatter_nd(indices, updates, shape)
+
+ with self.test_session():
+ self.assertEqual(scatter.eval().size, 0)
+
+ def testRank3InvalidShape1(self):
+ indices = tf.zeros([3, 2, 2], tf.int32)
+ updates = tf.zeros([2, 2, 2], tf.int32)
+ shape = np.array([2, 2, 2])
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The outer \d+ dimensions of indices\.shape="):
+ tf.scatter_nd(indices, updates, shape)
+
+ ref = tf.Variable(tf.zeros(shape, tf.int32))
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The outer \d+ dimensions of indices\.shape="):
+ tf.scatter_nd_update(ref, indices, updates)
+
+ def testRank3InvalidShape2(self):
+ indices = tf.zeros([2, 2, 1], tf.int32)
+ updates = tf.zeros([2, 2], tf.int32)
+ shape = np.array([2, 2, 2])
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The inner \d+ dimensions of output.shape="):
+ tf.scatter_nd(indices, updates, shape)
+
+ ref = tf.Variable(tf.zeros(shape, tf.int32))
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The inner \d+ dimensions of ref\.shape="):
+ tf.scatter_nd_update(ref, indices, updates)
+
+ def testGradientsRank2ElementUpdate(self):
+ indices = tf.constant([[0, 0], [1, 1]], dtype=tf.int32)
+ updates = tf.constant([1, 4], dtype=tf.float64)
+ shape = tf.constant([2, 2], dtype=tf.int32)
+ outputs = tf.scatter_nd(indices, updates, shape)
+
+ grad_vals = tf.constant([[1, 2], [3, 4]], dtype=tf.float64)
+ grads = tf.gradients([outputs], [updates], [grad_vals])[0]
+ expected_grads = np.array([1, 4], dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
+ def testGradientsRank2SliceUpdate(self):
+ indices = tf.constant([[1], [0]], dtype=tf.int32)
+ updates = tf.constant([[3, 4], [1, 2]], dtype=tf.float64)
+ shape = tf.constant([2, 2], dtype=tf.int32)
+ outputs = tf.scatter_nd(indices, updates, shape)
+
+ grad_vals = tf.constant([[3, 4], [1, 2]], dtype=tf.float64)
+ grads = tf.gradients([outputs], [updates], [grad_vals])[0]
+ expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
+ def testGradientsRank3SliceUpdate(self):
+ indices = tf.constant([[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=tf.int32)
+ updates = tf.constant(
+ [[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=tf.float64)
+ shape = tf.constant([2, 2, 2], dtype=tf.int32)
+ outputs = tf.scatter_nd(indices, updates, shape)
+
+ grad_vals = tf.constant(
+ [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=tf.float64)
+ grads = tf.gradients([outputs], [updates], [grad_vals])[0]
+ expected_grads = np.array(
+ [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
+ with self.test_session():
+ self.assertAllEqual(expected_grads, grads.eval())
+
+ def testConcurrentUpdates(self):
+ num_updates = 10000
+ update_values = np.random.rand(num_updates)
+ ref = tf.Variable(np.zeros([2, 2]), dtype=tf.float64)
+ indices = tf.constant([[0, 1]] * num_updates, dtype=tf.int32)
+ updates = tf.constant(update_values, dtype=tf.float64)
+
+ exepected_result = np.zeros([2, 2], dtype=np.float64)
+ exepected_result[0, 1] = np.sum(update_values)
+
+ scatter = tf.scatter_nd_add(ref, indices, updates)
+ init = tf.initialize_all_variables()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ result = sess.run(scatter)
+ assert np.allclose(result, exepected_result)
+
+ # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
+ def _disabledTestScatterOutOfRangeGpu(self):
+ if not tf.test.IsBuiltWithCuda():
+ return
+ for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
+ tf.scatter_nd_div, tf.scatter_nd_update):
+ params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
+ updates = np.array([-3, -4, -5]).astype(np.float32)
+ # With GPU, the code ignores indices that are out of range.
+ # We don't test the implementation; just test that there are no failures.
+ with self.test_session(force_gpu=True):
+ ref = tf.Variable(params)
+ ref.initializer.run()
+
+ # Indices all in range, no problem.
+ indices = np.array([2, 0, 5])
+ op(ref, indices, updates).eval()
+
+ # Indices out of range should not fail.
+ indices = np.array([-1, 0, 5])
+ op(ref, indices, updates).eval()
+ indices = np.array([2, 0, 6])
+ op(ref, indices, updates).eval()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 40f4ceb69d..b97fbff644 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -301,8 +301,12 @@ def _GatherGrad(op, grad):
@ops.RegisterGradient("GatherNd")
-def _GatherNdGrad(unused_op, unused_grad):
- raise NotImplementedError("Gradient for gather_nd is not implemented.")
+def _GatherNdGrad(op, grad):
+  """Gradient for GatherNd.
+
+  Passes `grad` to `scatter_nd` with the op's indices and the dynamic shape
+  of the first input; the indices input receives no gradient.
+
+  Args:
+    op: The GatherNd Operation; inputs are read as (params, indices).
+    grad: Gradient with respect to the op's output.
+
+  Returns:
+    A two-element list: the gradient for the first input, and None.
+  """
+  params, indices = op.inputs[0], op.inputs[1]
+  params_grad = array_ops.scatter_nd(indices, grad, array_ops.shape(params))
+  return [params_grad, None]
@ops.RegisterGradient("CheckNumerics")
@@ -566,3 +570,10 @@ def _ExtractImagePatchesGrad(op, grad):
grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))
return [grad_out]
+
+
+@ops.RegisterGradient("ScatterNd")
+def _ScatterNdGrad(op, grad):
+  """Gradient for ScatterNd.
+
+  Only the second input (updates) gets a gradient: the result of `gather_nd`
+  on `grad` at the op's indices. The first and third inputs get None.
+
+  Args:
+    op: The ScatterNd Operation; its first input is used as the indices.
+    grad: Gradient with respect to the op's output.
+
+  Returns:
+    A three-element list: None, the gradient for updates, None.
+  """
+  return [None, array_ops.gather_nd(grad, op.inputs[0]), None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e1f4ce2172..01e9fd9aa7 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -71,6 +71,7 @@ or join multiple tensors together.
@@gather
@@gather_nd
@@unique_with_counts
+@@scatter_nd
@@dynamic_partition
@@dynamic_stitch
@@boolean_mask
@@ -2461,3 +2462,50 @@ def _QuantizedReshapeShape(op):
ops.RegisterShape("QuantizeV2")(None)
ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None)
ops.RegisterShape("QuantizedConcat")(None)
+
+
+@ops.RegisterShape("ScatterNd")
+def _ScatterNdShape(op):
+  """Shape function for the ScatterNd op.
+
+  The shape of the output is given by the op's third input (`shape`).
+
+  Args:
+    op: A ScatterNd Operation with inputs (indices, updates, shape).
+
+  Returns:
+    A single-element list containing the shape of the output.
+
+  Raises:
+    ValueError: if indices/updates are given for an empty output shape, or if
+      the statically known dimensions of indices, updates and the output
+      shape are inconsistent.
+  """
+  indices_shape = op.inputs[0].get_shape()
+  updates_shape = op.inputs[1].get_shape()
+
+  # The result may have unknown rank or dimensions (hence the `ndims` check
+  # below), e.g. when the `shape` input is not a constant.
+  output_shape = tensor_util.constant_value_as_shape(op.inputs[2])
+  if output_shape.num_elements() == 0 and not (
+      indices_shape.num_elements() in (None, 0) and
+      updates_shape.num_elements() in (None, 0)):
+    raise ValueError("Indices and updates specified for empty output shape")
+
+  if indices_shape.ndims is not None:
+    outer_dims = len(indices_shape) - 1
+    # Index depth: the size of the last dimension of indices. None when that
+    # dimension is not statically known.
+    ixdim = indices_shape[-1].value
+
+    if not indices_shape[:outer_dims].is_compatible_with(
+        updates_shape[:outer_dims]):
+      raise ValueError("The outer %d dimensions of indices.shape=%s must "
+                       "match the outer %d dimensions of updates.shape=%s" % (
+                           outer_dims, indices_shape, outer_dims,
+                           updates_shape))
+    # The inner-dimension check needs a known index depth. Treating an
+    # unknown depth as 0 would compare the whole output shape against the
+    # inner dimensions of updates and reject valid graphs.
+    if ixdim is not None and output_shape.ndims is not None:
+      if not output_shape[ixdim:].is_compatible_with(
+          updates_shape[outer_dims:]):
+        raise ValueError("The inner %d dimensions of output.shape=%s must "
+                         "match the inner %d dimensions of updates.shape=%s" % (
+                             len(output_shape)-ixdim, output_shape,
+                             len(updates_shape)-outer_dims, updates_shape))
+
+  # Return whatever is known about the output even when the rank of indices
+  # is unknown, instead of discarding it.
+  return [output_shape]
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 9267b8ef2e..847d1b99c8 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -67,6 +67,11 @@ from tensorflow.python.ops.state_ops import scatter_div
from tensorflow.python.ops.state_ops import scatter_mul
from tensorflow.python.ops.state_ops import scatter_sub
from tensorflow.python.ops.state_ops import scatter_update
+from tensorflow.python.ops.state_ops import scatter_nd_add
+from tensorflow.python.ops.state_ops import scatter_nd_sub
+from tensorflow.python.ops.state_ops import scatter_nd_mul
+from tensorflow.python.ops.state_ops import scatter_nd_div
+from tensorflow.python.ops.state_ops import scatter_nd_update
from tensorflow.python.ops.string_ops import *
from tensorflow.python.ops.template import *
from tensorflow.python.ops.tensor_array_ops import *
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
index 871ce780c5..314f9f0c1a 100644
--- a/tensorflow/python/ops/state_grad.py
+++ b/tensorflow/python/ops/state_grad.py
@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import state_ops
+
# TODO(b/31222613): These ops may be differentiable, and there may be
# latent bugs here.
@@ -43,3 +43,14 @@ ops.NotDifferentiable("ScatterMul")
ops.NotDifferentiable("ScatterDiv")
+
+
+ops.NotDifferentiable("ScatterNdUpdate")
+
+ops.NotDifferentiable("ScatterNdAdd")
+
+ops.NotDifferentiable("ScatterNdSub")
+
+ops.NotDifferentiable("ScatterNdMul")
+
+ops.NotDifferentiable("ScatterNdDiv")
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 636acc3e2a..eede28cdc9 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -95,6 +95,11 @@ automatically by the optimizers in most cases.
@@scatter_sub
@@scatter_mul
@@scatter_div
+@@scatter_nd_update
+@@scatter_nd_add
+@@scatter_nd_sub
+@@scatter_nd_mul
+@@scatter_nd_div
@@sparse_mask
@@IndexedSlices
@@ -209,3 +214,34 @@ ops.RegisterShape("ScatterDiv")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterMul")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterSub")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterUpdate")(common_shapes.call_cpp_shape_fn)
+
+
+@ops.RegisterShape("ScatterNdAdd")
+@ops.RegisterShape("ScatterNdSub")
+@ops.RegisterShape("ScatterNdMul")
+@ops.RegisterShape("ScatterNdDiv")
+@ops.RegisterShape("ScatterNdUpdate")
+def scatter_nd_update_shape(op):
+  """Shape function for the ScatterNd update ops.
+
+  Args:
+    op: A ScatterNdAdd/Sub/Mul/Div/Update Operation with inputs
+      (ref, indices, updates).
+
+  Returns:
+    A single-element list containing the shape of `ref`.
+
+  Raises:
+    ValueError: if the statically known dimensions of ref, indices and
+      updates are inconsistent.
+  """
+  ref_shape = op.inputs[0].get_shape()
+  indices_shape = op.inputs[1].get_shape()
+  updates_shape = op.inputs[2].get_shape()
+
+  if indices_shape.ndims is not None and ref_shape.ndims is not None:
+    outer_dims = len(indices_shape) - 1
+    # Index depth: the size of the last dimension of indices. None when that
+    # dimension is not statically known.
+    ixdim = indices_shape[-1].value
+
+    if not indices_shape[:outer_dims].is_compatible_with(
+        updates_shape[:outer_dims]):
+      raise ValueError("The outer %d dimensions of indices.shape=%s must "
+                       "match the outer %d dimensions of updates.shape=%s" % (
+                           outer_dims, indices_shape, outer_dims,
+                           updates_shape))
+
+    # The inner-dimension check needs a known index depth. Treating an
+    # unknown depth as 0 would compare all of ref.shape against the inner
+    # dimensions of updates and reject valid graphs.
+    if ixdim is not None and not ref_shape[ixdim:].is_compatible_with(
+        updates_shape[outer_dims:]):
+      raise ValueError("The inner %d dimensions of ref.shape=%s must match "
+                       "the inner %d dimensions of updates.shape=%s" % (
+                           len(ref_shape)-ixdim, ref_shape,
+                           len(updates_shape)-outer_dims, updates_shape))
+
+  return [ref_shape]