diff options
-rw-r--r-- | tensorflow/contrib/makefile/tf_op_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 21 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 547 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.h | 52 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op_test.cc | 320 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 77 | ||||
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 235 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/gather_nd_op_test.py | 72 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 383 | ||||
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 15 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 48 | ||||
-rw-r--r-- | tensorflow/python/ops/standard_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/state_grad.py | 13 | ||||
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 36 |
15 files changed, 1821 insertions, 11 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index ed5d6539b3..ea1612201e 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -184,3 +184,4 @@ tensorflow/core/ops/control_flow_ops.cc tensorflow/core/ops/candidate_sampling_ops.cc tensorflow/core/ops/array_ops.cc tensorflow/core/ops/array_grad.cc +tensorflow/core/ops/scatter_nd_op.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 34954f0066..0a7a42fde6 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2025,11 +2025,13 @@ tf_kernel_libraries( "count_up_to_op", "dense_update_ops", "scatter_op", + "scatter_nd_op", "variable_ops", ], deps = [ ":assign_op", ":bounds_check", + ":fill_functor", ":scatter_functor", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -2043,6 +2045,7 @@ tf_cc_test( size = "small", srcs = ["scatter_op_test.cc"], deps = [ + ":fill_functor", ":ops_testutil", ":ops_util", ":scatter_op", @@ -2055,6 +2058,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "scatter_nd_op_test", + size = "small", + srcs = ["scatter_nd_op_test.cc"], + deps = [ + ":ops_testutil", + ":ops_util", + ":scatter_nd_op", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_libraries( name = "string", prefixes = [ @@ -2497,6 +2517,7 @@ filegroup( "debug_ops.*", # Ops excluded because they do not build correctly for Android. # See b/29213790 + "scatter_nd_op.*", "sparse_matmul_op.*", ], ), diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc new file mode 100644 index 0000000000..61098c7802 --- /dev/null +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -0,0 +1,547 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/state_ops.cc. +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/scatter_nd_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// Check whether updates.shape = indices.shape[0] + params.shape[IXDIM:] +static bool ValidUpdateShape(const TensorShape& params_shape, + const Tensor& indices, const Tensor& updates) { + int64 indices_nd = 1; + if (indices.dims() > 1) { + indices_nd = indices.dim_size(1); + } + for (int d = indices_nd; d < params_shape.dims(); d++) { + if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) { + return false; + } + } + return true; +} + +template <typename Index> +static void PrepareAndValidateInputs(OpKernelContext* c, + const TensorShape& params_shape, + const Tensor& indices, + const Tensor& updates, int64* indices_nd, + Index* num_updates, Index* slice_size) { + const TensorShape& indices_shape(indices.shape()); + const TensorShape& updates_shape(updates.shape()); + + OP_REQUIRES( + c, TensorShapeUtils::IsVectorOrHigher(params_shape), + errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ", + params_shape.DebugString())); + + OP_REQUIRES(c, params_shape.num_elements() >= 0 || + (indices.NumElements() == 0 && updates.NumElements() == 0), + errors::InvalidArgument( + "Indices and updates specified for empty output", " shape")); + + OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0), + errors::InvalidArgument( + "The outermost dimension of updates and indices ", + "must match. Got indices.shape ", indices_shape.DebugString(), + ", updates.shape ", updates_shape.DebugString())); + OP_REQUIRES( + c, ValidUpdateShape(params_shape, indices, updates), + errors::InvalidArgument( + "Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ", + "got updates.shape ", updates_shape.DebugString(), ", indices.shape ", + indices_shape.DebugString(), ", params_shape ", + params_shape.DebugString())); + // Check that we have enough index space + const int64 N_big = indices.NumElements(); + OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(), + errors::InvalidArgument( + "indices has too many elements for ", + DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ", + N_big, " > ", std::numeric_limits<Index>::max())); + OP_REQUIRES( + c, params_shape.dim_size(0) <= std::numeric_limits<Index>::max(), + errors::InvalidArgument("params_shape[0] too large for ", + DataTypeString(DataTypeToEnum<Index>::v()), + " indexing: ", params_shape.dim_size(0), " > ", + std::numeric_limits<Index>::max())); + + // Calculate the number of dimensions in indices + *indices_nd = 1; + if (indices_shape.dims() > 1) { + *indices_nd = indices_shape.dim_size(indices_shape.dims() - 1); + } + + // Calculate the number of elements that make up each slice of our updated + // tensor. This allows us to work with flattened tensors and copy over whole + // slices at a time. + Index total_nd = params_shape.dims(); + + int64 slice_size_big = 1; + for (int64 i = *indices_nd; i < total_nd; ++i) { + slice_size_big *= params_shape.dim_size(i); + } + + OP_REQUIRES(c, slice_size_big <= std::numeric_limits<Index>::max(), + errors::InvalidArgument("slice size is too large for indexing: ", + slice_size_big, " > ", + std::numeric_limits<Index>::max())); + + *slice_size = static_cast<Index>(slice_size_big); + + const int64 safe_indices_nd = (*indices_nd < 1) ? 1 : *indices_nd; + *num_updates = indices_shape.num_elements() / safe_indices_nd; +} + +template <typename Device, typename T, typename Index> +class ScatterNdOp : public OpKernel { + public: + explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) { + const DataType dt = DataTypeToEnum<T>::v(); + const DataType index_t = DataTypeToEnum<Index>::v(); + OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt})); + } + + void Compute(OpKernelContext* c) override { + const Tensor& indices = c->input(0); + const Tensor& updates = c->input(1); + const Tensor& shape_input = c->input(2); + + OP_REQUIRES(c, shape_input.dims() == 1, + errors::InvalidArgument("Shape must be a vector")); + auto vec = shape_input.flat<Index>(); + TensorShape shape; + TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape); + + int64 indices_nd; + Index num_updates; + Index slice_size; + PrepareAndValidateInputs<Index>(c, shape, indices, updates, &indices_nd, + &num_updates, &slice_size); + if (!c->status().ok()) return; + + Tensor scratch; + OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch)); + + auto scratch_scalar = scratch.scalar<Index>(); + auto indices_flat = indices.flat_inner_dims<Index>(); + auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size}); + + Index bad_i = -1; + switch (indices_nd) { +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + Tensor* out = nullptr; \ + OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \ + functor::SetZeroFunctor<Device, T> fill; \ + fill(c->eigen_device<Device>(), out->flat<T>()); \ + if (shape.num_elements() > 0) { \ + auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \ + functor::ScatterNdFunctor<Device, T, Index, \ + scatter_nd_op::UpdateOp::ADD, (IXDIM)> \ + functor; \ + bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \ + output_flat, indices_flat, updates_flat, output_flat); \ + } \ + } break + PARAMS_CASE(0); + PARAMS_CASE(1); + PARAMS_CASE(2); + PARAMS_CASE(3); + PARAMS_CASE(4); + PARAMS_CASE(5); +#undef PARAMS_CASE + default: + OP_REQUIRES(c, false, + errors::InvalidArgument( + "Only indices.shape[-1] values between 0 and 5 " + "are currently supported. Requested rank: ", + indices_nd)); + } + OP_REQUIRES( + c, bad_i < 0, + errors::InvalidArgument( + "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), + " = [", str_util::Join(gtl::ArraySlice<Index>( + &indices_flat(bad_i, 0), indices_nd), + ", "), + "] does not index into ", shape.DebugString())); + } +}; + +template <typename Device, typename T, typename Index, + scatter_nd_op::UpdateOp op> +class ScatterNdUpdateOp : public OpKernel { + public: + explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) { + const DataType dt = DataTypeToEnum<T>::v(); + const DataType dt_ref = DataTypeToEnum<T>::ref(); + const DataType index_t = DataTypeToEnum<Index>::v(); + OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref})); + OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* c) override { + if (use_exclusive_lock_) { + // Hold mutex while we apply updates + mutex_lock l(*c->input_ref_mutex(0)); + DoCompute(c); + } else { + DoCompute(c); + } + } + + private: + bool use_exclusive_lock_; + + void DoCompute(OpKernelContext* c) { + Tensor params = c->mutable_input(0, use_exclusive_lock_); + const Tensor& indices = c->input(1); + const Tensor& updates = c->input(2); + const TensorShape& params_shape(params.shape()); + + int64 indices_nd; + Index num_updates; + Index slice_size; + + OP_REQUIRES(c, params.IsInitialized(), + errors::FailedPrecondition("Null ref for params")); + PrepareAndValidateInputs<Index>(c, params_shape, indices, updates, + &indices_nd, &num_updates, &slice_size); + if (!c->status().ok()) return; + + Tensor scratch; + OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch)); + + auto scratch_scalar = scratch.scalar<Index>(); + auto indices_flat = indices.flat_inner_dims<Index>(); + auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size}); + + Index bad_i = -1; + c->forward_ref_input_to_ref_output(0, 0); + switch (indices_nd) { +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \ + functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \ + bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \ + params_flat, indices_flat, updates_flat, params_flat); \ + } break + PARAMS_CASE(0); + PARAMS_CASE(1); + PARAMS_CASE(2); + PARAMS_CASE(3); + PARAMS_CASE(4); + PARAMS_CASE(5); +#undef PARAMS_CASE + default: + OP_REQUIRES(c, false, + errors::InvalidArgument( + "Only indices.shape[-1] values between 1 and 5 " + "are currently supported. Requested rank: ", + indices_nd)); + } + OP_REQUIRES( + c, bad_i < 0, + errors::InvalidArgument( + "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), + " = [", str_util::Join(gtl::ArraySlice<Index>( + &indices_flat(bad_i, 0), indices_nd), + ", "), + "] is not in [0, ", params.dim_size(0), ")")); + } +}; + +// Specialization of ScatterNdSliceGenerator to CPU +namespace generator { + +template <typename T, typename Index, scatter_nd_op::UpdateOp op> +class UpdateExecutor { + public: + static void Update(T* input, const T* updates, T* output, Index slice_size); +}; + +template <typename T, typename Index> +class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> { + public: + static void Update(T* /* unused */, const T* updates, T* output, + Index slice_size) { + std::copy_n(updates, slice_size, output); + } +}; + +template <typename T, typename Index> +class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> { + public: + static void Update(T* input, const T* updates, T* output, Index slice_size) { + std::transform(input, input + slice_size, updates, output, std::plus<T>()); + } +}; + +template <typename T, typename Index> +class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> { + public: + static void Update(T* input, const T* updates, T* output, Index slice_size) { + std::transform(input, input + slice_size, updates, output, std::minus<T>()); + } +}; + +template <typename T, typename Index> +class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> { + public: + static void Update(T* input, const T* updates, T* output, Index slice_size) { + std::transform(input, input + slice_size, updates, output, + std::multiplies<T>()); + } +}; + +template <typename T, typename Index> +class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> { + public: + static void Update(T* input, const T* updates, T* output, Index slice_size) { + std::transform(input, input + slice_size, updates, output, + std::divides<T>()); + } +}; + +template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM> +class ScatterNdSliceGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator( + const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams, + typename TTypes<Index, 2>::ConstTensor Tindices, + typename TTypes<T, 2>::ConstTensor Tupdates, + typename TTypes<T, IXDIM + 1>::Tensor Toutput, + std::atomic<Index>* error_loc) + : slice_size_(slice_size), + Tparams_(Tparams), + Tindices_(Tindices), + Tupdates_(Tupdates), + Toutput_(Toutput), + error_loc_(error_loc) {} + + EIGEN_DEVICE_FUNC bool GenerateIndices( + const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const { + (*ix)[IXDIM] = 0; + bool out_of_bounds = false; + for (int i = 0; i < IXDIM; ++i) { + const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); + (*ix)[i] = ix_i; + out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); + } + return out_of_bounds; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 + operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const { + auto loc = loc_array[0]; + Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix_params; + Eigen::array<Eigen::DenseIndex, 2> ix_updates; + ix_updates[0] = loc; + ix_updates[1] = 0; + const bool out_of_bounds = GenerateIndices(loc, &ix_params); + if (TF_PREDICT_FALSE(out_of_bounds)) { + error_loc_->store(loc); + } else { + UpdateExecutor<T, Index, op>::Update(&Tparams_(ix_params), + &Tupdates_(ix_updates), + &Toutput_(ix_params), slice_size_); + } + return static_cast<int32>(0); // Return something... + } + + protected: + const Index slice_size_; + mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_; + const typename TTypes<Index, 2>::ConstTensor Tindices_; + const typename TTypes<T, 2>::ConstTensor Tupdates_; + mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_; + std::atomic<Index>* error_loc_; +}; + +} // namespace generator + +namespace functor { +// Implementation of update functor for CPU. +template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM> +struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> { + Index operator()(const CPUDevice& d, const Index slice_size, + typename TTypes<Index>::Scalar Tscratch, + typename TTypes<T, IXDIM + 1>::Tensor Tparams, + typename TTypes<Index, 2>::ConstTensor Tindices, + typename TTypes<T, 2>::ConstTensor Tupdates, + typename TTypes<T, IXDIM + 1>::Tensor Toutput) { + std::atomic<Index> error_loc(-1); + + const Eigen::DenseIndex batch_size = Tindices.dimension(0); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }}; + Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }}; +#else + Eigen::IndexList<Eigen::type2index<1> > reshape_dims; + Eigen::IndexList<Eigen::DenseIndex> broadcast_dims; + broadcast_dims.set(0, batch_size); +#endif + + generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator( + slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc); + Tscratch.device(d) = Tscratch.reshape(reshape_dims) + .broadcast(broadcast_dims) + .generate(generator) + .sum(); + + // error_loc() returns -1 if there's no out-of-bounds index, + // otherwise it returns the location of an OOB index in Tindices. + return error_loc.load(); + } +}; +} // namespace functor + +#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \ + REGISTER_KERNEL_BUILDER(Name(name) \ + .Device(DEVICE_##dev) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + ScatterNdOp<dev##Device, type, index_type>) + +#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \ + op) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_##dev) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + ScatterNdUpdateOp<dev##Device, type, index_type, op>) + +#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \ + REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \ + REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name) + +#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \ + REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op) + +#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ + scatter_nd_op::UpdateOp::ADD); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \ + scatter_nd_op::UpdateOp::SUB); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \ + scatter_nd_op::UpdateOp::MUL); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \ + scatter_nd_op::UpdateOp::DIV); + +#define REGISTER_SCATTER_ND(type, dev) \ + REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd"); + +#define REGISTER_SCATTER_ND_UPDATE(type, dev) \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \ + scatter_nd_op::UpdateOp::ASSIGN); + +// Registers CPU kernels. +#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \ + REGISTER_SCATTER_ND_ADD_SUB(type, CPU); + +#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \ + REGISTER_SCATTER_ND_UPDATE(type, CPU); + +#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU); + +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU); +TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU); +TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU); + +// Registers GPU kernels. +#if GOOGLE_CUDA +#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \ + REGISTER_SCATTER_ND_ADD_SUB(type, GPU); + +#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \ + REGISTER_SCATTER_ND_UPDATE(type, GPU); + +// TODO(simister): Re-enable when GPU support is working. +// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU); +// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU); + +#endif // GOOGLE_CUDA + +#undef REGISTER_SCATTER_ND_ADD +#undef REGISTER_SCATTER_ND_ADD_SUB +#undef REGISTER_SCATTER_ND_ADD_SUB_CPU +#undef REGISTER_SCATTER_ND_ADD_SUB_GPU +#undef REGISTER_SCATTER_ND_UPDATE +#undef REGISTER_SCATTER_ND_UPDATE_CPU +#undef REGISTER_SCATTER_ND_UPDATE_GPU +#undef REGISTER_SCATTER_ND_KERNEL +#undef REGISTER_SCATTER_ND_KERNEL_INDEX + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { + +#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \ + template <> \ + Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes<T, IXDIM>::Tensor params, \ + typename TTypes<Index, 2>::ConstTensor indices, \ + typename TTypes<T, 2>::ConstTensor updates); \ + extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>; + +#define DECLARE_GPU_SPECS_OPS(T, Index, op) \ + DECLARE_GPU_SPECS_OP(T, Index, op, 0); \ + DECLARE_GPU_SPECS_OP(T, Index, op, 1); \ + DECLARE_GPU_SPECS_OP(T, Index, op, 2); \ + DECLARE_GPU_SPECS_OP(T, Index, op, 3); \ + DECLARE_GPU_SPECS_OP(T, Index, op, 4); \ + DECLARE_GPU_SPECS_OP(T, Index, op, 5) + +#define DECLARE_GPU_SPECS_INDEX(T, Index) \ + DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \ + DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \ + DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \ + DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \ + DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV); + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPECS_INDEX(T, int32); \ + DECLARE_GPU_SPECS_INDEX(T, int64); + +// TODO(simister): Re-enable when GPU support is working. +// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); + +#undef DECLARE_GPU_SPECS +#undef DECLARE_GPU_SPECS_INDEX +#undef DECLARE_GPU_SPECS_OPS +#undef DECLARE_GPU_SPECS_OP + +} // namespace functor +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h new file mode 100644 index 0000000000..e4c8e7ed9f --- /dev/null +++ b/tensorflow/core/kernels/scatter_nd_op.h @@ -0,0 +1,52 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_SCATTER_ND_OP_H_ +#define TENSORFLOW_KERNELS_SCATTER_ND_OP_H_ + +// Functor definitions for ScatterND ops, must be compilable by nvcc. + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace scatter_nd_op { + +enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV }; + +} // namespace scatter_nd_op + +namespace functor { + +// Functor used by ScatterOp to do the computations. +template <typename Device, typename T, typename Index, + scatter_nd_op::UpdateOp op, int IXDIM> +struct ScatterNdFunctor { + // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes<Index>::Scalar Tscratch, + typename TTypes<T, IXDIM>::Tensor params, + typename TTypes<T, 2>::ConstTensor indices, + typename TTypes<T, 2>::ConstTensor updates, + typename TTypes<T, IXDIM + 1>::Tensor Toutput); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SCATTER_ND_OP_H_ diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc new file mode 100644 index 0000000000..d6743a6867 --- /dev/null +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -0,0 +1,320 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <functional> +#include <memory> +#include <vector> + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +class ScatterNdUpdateOpTest : public OpsTestBase { + protected: + void MakeOp(DataType variable_ref_type, DataType index_type) { + TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate") + .Input(FakeInput(variable_ref_type)) + .Input(FakeInput(index_type)) + .Input(FakeInput(RemoveRefType(variable_ref_type))) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } +}; + +TEST_F(ScatterNdUpdateOpTest, Simple_StringType) { + MakeOp(DT_STRING_REF, DT_INT32); + AddInputFromArray<string>(TensorShape({1}), {"Brain"}); + AddInputFromArray<int32>(TensorShape({1}), {0}); + AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"}); + TF_ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_STRING, TensorShape({1})); + test::FillValues<string>(&expected, {"TensorFlow"}); + test::ExpectTensorEqual<string>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) { + MakeOp(DT_BOOL_REF, DT_INT32); + AddInputFromArray<bool>(TensorShape({1}), {false}); + AddInputFromArray<int32>(TensorShape({1}), {0}); + AddInputFromArray<bool>(TensorShape({1}), {true}); + TF_ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_BOOL, TensorShape({1})); + test::FillValues<bool>(&expected, {true}); + test::ExpectTensorEqual<bool>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2}); + AddInputFromArray<float>(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); + test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, + 10002, 0, 0, 0, 777, 778, 779}); + test::ExpectTensorEqual<float>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, Simple_Two64) { + MakeOp(DT_FLOAT_REF, DT_INT64); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2}); + AddInputFromArray<float>(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); + test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, + 10002, 0, 0, 0, 777, 778, 779}); + test::ExpectTensorEqual<float>(expected, params_tensor); +} +/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({0}), {}); + AddInputFromArray<int32>(TensorShape({0}), {}); + AddInputFromArray<float>(TensorShape({0}), {}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Output must not have 0 elements, got shape: ")) + << s; +}*/ + +TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({1}), {3}); + AddInputFromArray<float>(TensorShape({1}), {101}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + test::FillValues<float>(&expected, {0, 0, 0, 101, 0}); + test::ExpectTensorEqual<float>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, Simple_OneD) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2}); + AddInputFromArray<float>(TensorShape({3}), {100, 101, 102}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + test::FillValues<float>(&expected, {100, 0, 102, 0, 101}); + test::ExpectTensorEqual<float>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, HigherRank) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6}); + AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_FLOAT, TensorShape({8})); + test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0}); + test::ExpectTensorEqual<float>(expected, params_tensor); +} + +TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99}); + AddInputFromArray<float>(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Invalid indices: [2,0] = [99] is not in [0, 5)")) + << s; +} + +TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99}); + AddInputFromArray<float>(TensorShape({3, 3}), + {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("The outermost dimension of updates and indices " + "must match. Got indices.shape [1,3,1], " + "updates.shape [3,3]")) + << s; +} + +TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2}); + AddInputFromArray<float>( + TensorShape({3, 4}), + {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Must have updates.shape = indices.shape[0] + " + "params_shape[IXDIM:], got")) + + << s; +} + +TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { + MakeOp(DT_FLOAT_REF, DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2}); + AddInputFromArray<float>(TensorShape({2, 3}), + {100, 101, 102, 10000, 10001, 10002}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("The outermost dimension of updates and indices " + "must match. Got ")) + << s; +} + +class ScatterNdUpdateBM : public ScatterNdUpdateOpTest { + public: + virtual void TestBody() {} + void MakeBenchmarkOp(const char* op, DataType index_type) { + TF_ASSERT_OK(NodeDefBuilder("myop", op) + .Input(FakeInput(DT_FLOAT_REF)) + .Input(FakeInput(index_type)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + } +}; + +template <typename Index> +static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) { + testing::StopTiming(); + const int kRows = 10000000 / embedding_size; + std::vector<float> values; + values.reserve(kRows); + for (int i = 0; i < kRows * embedding_size; i++) { + values.push_back(i); + } + const int kNumUpdates = 1000; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector<Index> indices; + std::vector<float> updates; + for (int i = 0; i < kNumUpdates; i++) { + indices.push_back(rnd.Uniform(kRows)); + for (int j = 0; j < embedding_size; j++) { + updates.push_back(i * 10 + j); + } + } + + ScatterNdUpdateBM bm; + bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v()); + bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values); + bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices); + bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}), + updates); + testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * + iters); + testing::StartTiming(); + while (iters-- > 0) { + Status s = bm.RunOpKernel(); + } + testing::StopTiming(); +} + +static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) { + BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate"); +} +static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) { + BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate"); +} + +static void BM_ScatterNdAddInt32(int iters, int embedding_size) { + BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd"); +} +static void BM_ScatterNdAddInt64(int iters, int embedding_size) { + BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd"); +} + +BENCHMARK(BM_ScatterNdUpdateInt32) + ->Arg(1) + ->Arg(10) + ->Arg(64) + ->Arg(256) + ->Arg(1024); +BENCHMARK(BM_ScatterNdUpdateInt64) + ->Arg(1) + ->Arg(10) + ->Arg(64) + ->Arg(256) + ->Arg(1024); + +BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 6e076a092e..aff547f78a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -4382,6 +4382,83 @@ output_min: This value is copied from input_min. output_max: This value is copied from input_max. )Doc"); +REGISTER_OP("ScatterNd") + .Input("indices: Tindices") + .Input("updates: T") + .Input("shape: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .Doc( + R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices. +This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor. + +TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax. + +`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `shape`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `shape`. + +`updates` is Tensor of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]]. +``` + +The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../../images/ScatterNd1.png" alt> +</div> + +In Python, this scatter operation would look like this: + + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + shape = tf.constant([8]) + scatter = tf.scatter_nd(indices, updates, shape) + with tf.Session() as sess: + print sess.run(scatter) + +The resulting tensor would look like this: + + [0, 11, 0, 10, 9, 0, 0, 12] + +We can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../../images/ScatterNd2.png" alt> +</div> + +In Python, this scatter operation would look like this: + + indices = tf.constant([[0], [2]]) + updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]]]) + shape = tf.constant([4, 4, 4]) + scatter = tf.scatter_nd(indices, updates, shape) + with tf.Session() as sess: + print sess.run(scatter) + +The resulting tensor would look like this: + + [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref. +shape: A vector. The shape of the resulting tensor. +output: A new tensor with the given shape and updates applied according to the indices.)doc"); + REGISTER_OP("FakeQuantWithMinMaxArgs") .Attr("min: float = -6.0") .Attr("max: float = 6.0") diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index b9ac8b16ff..9339b9b821 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -445,6 +445,241 @@ use_locking: If True, the operation will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); +REGISTER_OP("ScatterNdUpdate") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = true") + .Doc( + R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this: + + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1] ,[7]]) + updates = tf.constant([9, 10, 11, 12]) + update = tf.scatter_nd_update(ref, indices, updates) + with tf.Session() as sess: + print sess.run(update) + +The resulting update to ref would look like this: + + [1, 11, 3, 10, 9, 6, 7, 12] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. + +ref: A mutable Tensor. Should be from a Variable node. +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); + +REGISTER_OP("ScatterNdAdd") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc( + R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this: + + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + add = tf.scatter_nd_add(ref, indices, updates) + with tf.Session() as sess: + print sess.run(add) + +The resulting update to ref would look like this: + + [1, 13, 3, 14, 14, 6, 7, 20] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. + +ref: A mutable Tensor. Should be from a Variable node. +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); + +REGISTER_OP("ScatterNdSub") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc( + R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this: + + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + sub = tf.scatter_nd_sub(ref, indices, updates) + with tf.Session() as sess: + print sess.run(sub) + +The resulting update to ref would look like this: + + [1, -9, 3, -6, -4, 6, 7, -4] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. + +ref: A mutable Tensor. Should be from a Variable node. +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); + +REGISTER_OP("ScatterNdMul") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc( + R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this: + + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + sub = tf.scatter_nd_mul(ref, indices, updates) + with tf.Session() as sess: + print sess.run(sub) + +The resulting update to ref would look like this: + + [1, 22, 3, 40, 45, 6, 7, 96] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. + +ref: A mutable Tensor. Should be from a Variable node. +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); + +REGISTER_OP("ScatterNdDiv") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc( + R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. + +`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `ref`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +dimension of `ref`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +``` + +For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this: + + ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80]) + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([2, 3, 4, 5]) + sub = tf.scatter_nd_div(ref, indices, updates) + with tf.Session() as sess: + print sess.run(sub) + +The resulting update to ref would look like this: + + [10, 5, 30, 13, 25, 60, 70, 16] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. + +ref: A mutable Tensor. Should be from a Variable node. +indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); + REGISTER_OP("CountUpTo") .Input("ref: Ref(T)") .Output("output: T") diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index baa48ec8a4..67a086d4fc 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -263,6 +263,13 @@ tf_py_test( ) tf_py_test( + name = "scatter_nd_ops_test", + size = "medium", + srcs = ["scatter_nd_ops_test.py"], + additional_deps = ["//tensorflow:tensorflow_py"], +) + +tf_py_test( name = "segment_reduction_ops_test", size = "small", srcs = ["segment_reduction_ops_test.py"], diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index f3fd47381a..13b3bec3c0 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -53,20 +53,20 @@ class GatherNdTest(tf.test.TestCase): gather_nd_ok_t = tf.gather_nd(params, indices_empty) gather_nd_ok_val = gather_nd_ok_t.eval() self.assertEqual([0], gather_nd_ok_t.get_shape()) - self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) indices_empty = np.empty((0, 1), dtype=np.int32) gather_nd_ok_t = tf.gather_nd(params, indices_empty) gather_nd_ok_val = gather_nd_ok_t.eval() self.assertEqual([0, 3], gather_nd_ok_t.get_shape()) - self.assertAllEqual(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val) + self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val) params_empty = np.empty((0, 3), dtype=np.float32) indices_empty = np.empty((0, 2), dtype=np.int32) gather_nd_ok_t = tf.gather_nd(params_empty, indices_empty) gather_nd_ok_val = gather_nd_ok_t.eval() self.assertEqual([0], gather_nd_ok_t.get_shape()) - self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) params_empty = np.empty((0, 3), dtype=np.float32) indices_nonempty = np.zeros((1, 2), dtype=np.int32) @@ -74,7 +74,7 @@ class GatherNdTest(tf.test.TestCase): with self.assertRaisesOpError( r"Requested more than 0 entries, but params is empty."): gather_nd_break_t.eval() - self.assertAllEqual(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) def testIndexScalar(self): with self.test_session(use_gpu=self.use_gpu): @@ -184,11 +184,11 @@ class GatherNdTest(tf.test.TestCase): indices = tf.placeholder(tf.int32) gather_nd_t = tf.gather_nd(params, indices) shape = gather_nd_t.get_shape() - self.assertEqual(shape.ndims, None) - self.assertEqual(shape[0].value, None) + self.assertEqual(None, shape.ndims) + self.assertEqual(None, shape[0].value) def testBadIndices(self): - with self.test_session(use_gpu=False): + with self.test_session(): params = [0, 1, 2] indices = [[[0], [7]]] # Make this one higher rank gather_nd = tf.gather_nd(params, indices) @@ -198,7 +198,7 @@ class GatherNdTest(tf.test.TestCase): gather_nd.eval() def testBadIndicesWithSlices(self): - with self.test_session(use_gpu=False): + with self.test_session(): params = [[0, 1, 2]] indices = [[[0], [0], [1]]] # Make this one higher rank gather_nd = tf.gather_nd(params, indices) @@ -207,6 +207,62 @@ class GatherNdTest(tf.test.TestCase): r"\(shape: \[1,3\]\)"): gather_nd.eval() + def testGradientsRank2Elements(self): + indices = tf.constant([[0, 0], [1, 1]], dtype=tf.int32) + inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float64) + outputs = tf.gather_nd(inputs, indices) + + grad_vals = tf.constant([1, 2], dtype=tf.float64) + grads = tf.gradients([outputs], [inputs], [grad_vals])[0] + expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64) + with self.test_session(): + assert np.array_equal(expected_grads, grads.eval()) + + def testGradientsRank2Slices(self): + indices = tf.constant([[1], [0]], dtype=tf.int32) + inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float64) + outputs = tf.gather_nd(inputs, indices) + + grad_vals = tf.constant([[1, 2], [3, 4]], dtype=tf.float64) + grads = tf.gradients([outputs], [inputs], [grad_vals])[0] + expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + + def testGradientsRank3Elements(self): + indices = tf.constant([[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=tf.int32) + inputs = tf.constant([[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=tf.float64) + outputs = tf.gather_nd(inputs, indices) + + grad_vals = tf.constant( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=tf.float64) + grads = tf.gradients([outputs], [inputs], [grad_vals])[0] + expected_grads = np.array( + [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + + def testGradientsRank2SlicesWithEmptySpace(self): + indices = tf.constant([[2], [0], [5]], dtype=tf.int32) + inputs = tf.constant( + [[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9]], + dtype=tf.float64) + outputs = tf.gather_nd(inputs, indices) + grad_vals = tf.constant( + [[1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3, 3, 3, 3]], + dtype=tf.float64) + grads = tf.gradients([outputs], [inputs], [grad_vals])[0] + expected_grads = np.array( + [[2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]], + dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + class GatherNdGpuTest(GatherNdTest): use_gpu = True diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py new file mode 100644 index 0000000000..2ad7ae810c --- /dev/null +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -0,0 +1,383 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.scatter_nd.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from operator import add +from operator import mul +from operator import sub + +import numpy as np +import tensorflow as tf + + +def _AsType(v, vtype): + return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v) + + +def _FlatInnerDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape([functools.reduce(mul, shape[:-ndims + 1], 1)] + shape[ + -ndims + 1:]) + + +def _FlatOuterDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape(shape[:ndims - 1] + + [functools.reduce(mul, shape[ndims - 1:], 1)]) + + +def _NumpyScatterNd(ref, indices, updates, op): + ixdim = indices.shape[-1] + num_updates = indices.size / ixdim + total_nd = len(ref.shape) + slice_size = 1 + for i in range(ixdim, total_nd): + slice_size *= ref.shape[i] + flat_indices = _FlatInnerDims(indices) + flat_updates = updates.reshape((num_updates, slice_size)) + output_flat = _FlatOuterDims(ref, ixdim + 1) + for ix_updates, ix_output in enumerate(flat_indices): + ix_output = tuple(ix_output) + output_flat[ix_output] = op(output_flat[ix_output], + flat_updates[ix_updates]) + return output_flat.reshape(ref.shape) + + +def _NumpyUpdate(ref, indices, updates): + return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) + + +def _NumpyAdd(ref, indices, updates): + return _NumpyScatterNd(ref, indices, updates, add) + + +def _NumpySub(ref, indices, updates): + return _NumpyScatterNd(ref, indices, updates, sub) + + +def _NumpyMul(ref, indices, updates): + return _NumpyScatterNd(ref, indices, updates, mul) + + +def _NumpyDiv(ref, indices, updates): + return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u) + + +class ScatterTest(tf.test.TestCase): + + def _VariableRankTest(self, + np_scatter, + tf_scatter, + vtype, + itype, + use_gpu, + repeat_indices=False): + np.random.seed(8) + ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)] + indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)] + with self.test_session(use_gpu=use_gpu): + for ref_shape, indices_shape in zip(ref_shapes, indices_shapes): + num_updates = indices_shape[0] + ixdim = indices_shape[-1] + + indexable_area_shape = () + for i in range(ixdim): + indexable_area_shape += (ref_shape[i],) + all_indices = [ + list(coord) + for coord, _ in np.ndenumerate( + np.empty(indexable_area_shape, vtype)) + ] + np.random.shuffle(all_indices) + indices = np.array(all_indices[:num_updates]) + + if num_updates > 1 and repeat_indices: + indices = indices[:num_updates // 2] + for _ in range(num_updates - num_updates // 2): + indices = np.append( + indices, [indices[np.random.randint(num_updates // 2)]], axis=0) + np.random.shuffle(indices) + indices = _AsType(indices[:num_updates], itype) + + updates_shape = (num_updates,) + for i in range(ixdim, len(ref_shape)): + updates_shape += (ref_shape[i],) + updates = _AsType(np.random.randn(*(updates_shape)), vtype) + ref = _AsType(np.random.randn(*(ref_shape)), vtype) + + # Scatter via numpy + new = ref.copy() + np_scatter(new, indices, updates) + # Scatter via tensorflow + ref_var = tf.Variable(ref) + ref_var.initializer.run() + tf_scatter(ref_var, indices, updates).eval() + # Compare + self.assertAllClose(new, ref_var.eval()) + + def _VariableRankTests(self, np_scatter, tf_scatter): + for vtype in (np.float32, np.float64): + for itype in (np.int32, np.int64): + for use_gpu in (False, True): + self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu) + + def testVariableRankUpdate(self): + self._VariableRankTests(_NumpyUpdate, tf.scatter_nd_update) + + def testVariableRankAdd(self): + self._VariableRankTests(_NumpyAdd, tf.scatter_nd_add) + + def testVariableRankSub(self): + self._VariableRankTests(_NumpySub, tf.scatter_nd_sub) + + def testVariableRankMul(self): + self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul) + + def testVariableRankDiv(self): + self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div) + + def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter): + for vtype in (np.float32, np.float64): + for itype in (np.int32, np.int64): + for use_gpu in (False, True): + self._VariableRankTest( + np_scatter, + tf_scatter, + vtype, + itype, + use_gpu, + repeat_indices=True) + + def testScatterRepeatIndices(self): + """This tests scatter_add using indices that repeat.""" + self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_nd_add) + self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_nd_sub) + self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul) + self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div) + + def testBooleanScatterUpdate(self): + with self.test_session(use_gpu=False) as session: + var = tf.Variable([True, False]) + update0 = tf.scatter_nd_update(var, [[1]], [True]) + update1 = tf.scatter_nd_update( + var, tf.constant( + [[0]], dtype=tf.int64), [False]) + var.initializer.run() + + session.run([update0, update1]) + + self.assertAllEqual([False, True], var.eval()) + + def testScatterOutOfRangeCpu(self): + for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul, + tf.scatter_nd_div, tf.scatter_nd_update): + params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) + updates = np.array([-3, -4, -5]).astype(np.float32) + with self.test_session(use_gpu=False): + ref = tf.Variable(params) + ref.initializer.run() + + # Indices all in range, no problem. + indices = np.array([[2], [0], [5]]) + op(ref, indices, updates).eval() + + # Test some out of range errors. + indices = np.array([[-1], [0], [5]]) + with self.assertRaisesOpError( + r"Invalid indices: \[0,0\] = \[-1\] is not in \[0, 6\)"): + op(ref, indices, updates).eval() + + indices = np.array([[2], [0], [6]]) + with self.assertRaisesOpError( + r"Invalid indices: \[2,0\] = \[6\] is not in \[0, 6\)"): + op(ref, indices, updates).eval() + + def testRank3ValidShape(self): + indices = tf.zeros([2, 2, 2], tf.int32) + updates = tf.zeros([2, 2, 2], tf.int32) + shape = np.array([2, 2, 2]) + self.assertAllEqual( + tf.scatter_nd(indices, updates, shape).get_shape().as_list(), shape) + + ref = tf.Variable(tf.zeros(shape, tf.int32)) + self.assertAllEqual( + tf.scatter_nd_update(ref, indices, updates).get_shape().as_list(), + shape) + + def testUndefinedIndicesShape(self): + indices = tf.placeholder(tf.int32, shape=None) + updates = tf.placeholder(tf.int32, shape=[2, 2, 2]) + shape = tf.constant([2, 2, 2], tf.int32) + tf.scatter_nd(indices, updates, shape) + + def testUndefinedUpdatesShape(self): + indices = tf.placeholder(tf.int32, shape=[2, 2, 2]) + updates = tf.placeholder(tf.int32, shape=None) + shape = tf.constant([2, 2, 2], tf.int32) + tf.scatter_nd(indices, updates, shape) + + def testUndefinedOutputShape(self): + indices = tf.placeholder(tf.int32, shape=[2, 2, 2]) + updates = tf.placeholder(tf.int32, shape=[2, 2, 2]) + shape = tf.placeholder(tf.int32, shape=[None]) + tf.scatter_nd(indices, updates, shape) + + def testEmptyoutputShape1(self): + indices = tf.zeros([2, 2, 2], tf.int32) + updates = tf.zeros([2, 2, 2], tf.int32) + shape = tf.constant([0, 3, 2], tf.int32) + + with self.assertRaisesWithPredicateMatch( + ValueError, "Indices and updates specified for empty output shape"): + tf.scatter_nd(indices, updates, shape) + + def testEmptyoutputShape2(self): + indices = tf.placeholder(tf.int32, shape=None) + updates = tf.placeholder(tf.int32, shape=None) + shape = tf.constant([0, 3, 2], tf.int32) + + with self.test_session(): + #with self.assertRaisesOpError( + # r"Indices and updates specified for empty output shape"): + tf.scatter_nd(indices, updates, shape).eval(feed_dict={ + indices: np.zeros( + [2, 2, 2], dtype=np.int32), + updates: np.zeros( + [2, 2, 2], dtype=np.int32) + }) + + def testEmptyoutputShape3(self): + indices = tf.zeros([0], tf.int32) + updates = tf.zeros([0], tf.int32) + shape = tf.constant([0], tf.int32) + scatter = tf.scatter_nd(indices, updates, shape) + + with self.test_session(): + self.assertEqual(scatter.eval().size, 0) + + def testRank3InvalidShape1(self): + indices = tf.zeros([3, 2, 2], tf.int32) + updates = tf.zeros([2, 2, 2], tf.int32) + shape = np.array([2, 2, 2]) + with self.assertRaisesWithPredicateMatch( + ValueError, "The outer \d+ dimensions of indices\.shape="): + tf.scatter_nd(indices, updates, shape) + + ref = tf.Variable(tf.zeros(shape, tf.int32)) + with self.assertRaisesWithPredicateMatch( + ValueError, "The outer \d+ dimensions of indices\.shape="): + tf.scatter_nd_update(ref, indices, updates) + + def testRank3InvalidShape2(self): + indices = tf.zeros([2, 2, 1], tf.int32) + updates = tf.zeros([2, 2], tf.int32) + shape = np.array([2, 2, 2]) + with self.assertRaisesWithPredicateMatch( + ValueError, "The inner \d+ dimensions of output.shape="): + tf.scatter_nd(indices, updates, shape) + + ref = tf.Variable(tf.zeros(shape, tf.int32)) + with self.assertRaisesWithPredicateMatch( + ValueError, "The inner \d+ dimensions of ref\.shape="): + tf.scatter_nd_update(ref, indices, updates) + + def testGradientsRank2ElementUpdate(self): + indices = tf.constant([[0, 0], [1, 1]], dtype=tf.int32) + updates = tf.constant([1, 4], dtype=tf.float64) + shape = tf.constant([2, 2], dtype=tf.int32) + outputs = tf.scatter_nd(indices, updates, shape) + + grad_vals = tf.constant([[1, 2], [3, 4]], dtype=tf.float64) + grads = tf.gradients([outputs], [updates], [grad_vals])[0] + expected_grads = np.array([1, 4], dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + + def testGradientsRank2SliceUpdate(self): + indices = tf.constant([[1], [0]], dtype=tf.int32) + updates = tf.constant([[3, 4], [1, 2]], dtype=tf.float64) + shape = tf.constant([2, 2], dtype=tf.int32) + outputs = tf.scatter_nd(indices, updates, shape) + + grad_vals = tf.constant([[3, 4], [1, 2]], dtype=tf.float64) + grads = tf.gradients([outputs], [updates], [grad_vals])[0] + expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + + def testGradientsRank3SliceUpdate(self): + indices = tf.constant([[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=tf.int32) + updates = tf.constant( + [[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=tf.float64) + shape = tf.constant([2, 2, 2], dtype=tf.int32) + outputs = tf.scatter_nd(indices, updates, shape) + + grad_vals = tf.constant( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=tf.float64) + grads = tf.gradients([outputs], [updates], [grad_vals])[0] + expected_grads = np.array( + [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64) + with self.test_session(): + self.assertAllEqual(expected_grads, grads.eval()) + + def testConcurrentUpdates(self): + num_updates = 10000 + update_values = np.random.rand(num_updates) + ref = tf.Variable(np.zeros([2, 2]), dtype=tf.float64) + indices = tf.constant([[0, 1]] * num_updates, dtype=tf.int32) + updates = tf.constant(update_values, dtype=tf.float64) + + exepected_result = np.zeros([2, 2], dtype=np.float64) + exepected_result[0, 1] = np.sum(update_values) + + scatter = tf.scatter_nd_add(ref, indices, updates) + init = tf.initialize_all_variables() + + with tf.Session() as sess: + sess.run(init) + result = sess.run(scatter) + assert np.allclose(result, exepected_result) + + # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU. + def _disabledTestScatterOutOfRangeGpu(self): + if not tf.test.IsBuiltWithCuda(): + return + for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul, + tf.scatter_nd_div, tf.scatter_nd_update): + params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) + updates = np.array([-3, -4, -5]).astype(np.float32) + # With GPU, the code ignores indices that are out of range. + # We don't test the implementation; just test there's no failures. + with self.test_session(force_gpu=True): + ref = tf.Variable(params) + ref.initializer.run() + + # Indices all in range, no problem. + indices = np.array([2, 0, 5]) + op(ref, indices, updates).eval() + + # Indicies out of range should not fail. + indices = np.array([-1, 0, 5]) + op(ref, indices, updates).eval() + indices = np.array([2, 0, 6]) + op(ref, indices, updates).eval() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 40f4ceb69d..b97fbff644 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -301,8 +301,12 @@ def _GatherGrad(op, grad): @ops.RegisterGradient("GatherNd") -def _GatherNdGrad(unused_op, unused_grad): - raise NotImplementedError("Gradient for gather_nd is not implemented.") +def _GatherNdGrad(op, grad): + ref = op.inputs[0] + ref_shape = array_ops.shape(ref) + indices = op.inputs[1] + ref_grad = array_ops.scatter_nd(indices, grad, ref_shape) + return [ref_grad, None] @ops.RegisterGradient("CheckNumerics") @@ -566,3 +570,10 @@ def _ExtractImagePatchesGrad(op, grad): grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3)) return [grad_out] + + +@ops.RegisterGradient("ScatterNd") +def _ScatterNdGrad(op, grad): + indices = op.inputs[0] + updates_grad = array_ops.gather_nd(grad, indices) + return [None, updates_grad, None] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index e1f4ce2172..01e9fd9aa7 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -71,6 +71,7 @@ or join multiple tensors together. @@gather @@gather_nd @@unique_with_counts +@@scatter_nd @@dynamic_partition @@dynamic_stitch @@boolean_mask @@ -2461,3 +2462,50 @@ def _QuantizedReshapeShape(op): ops.RegisterShape("QuantizeV2")(None) ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None) ops.RegisterShape("QuantizedConcat")(None) + + +@ops.RegisterShape("ScatterNd") +def _ScatterNdShape(op): + """Shape function for the ScatterNd op. + + The shape of the ouput is defined as a parameter on the Operation. + + Args: + op: A ScatterNd Operation. + + Returns: + A single-element list containing the shape of the output. + + Raises: + ValueError: if the arguments have invalid rank + """ + indices_shape = op.inputs[0].get_shape() + updates_shape = op.inputs[1].get_shape() + shape = op.inputs[2] + + output_shape = tensor_util.constant_value_as_shape(shape) + if output_shape.num_elements() == 0 and not ( + indices_shape.num_elements() in + (None, 0) and updates_shape.num_elements() in (None, 0)): + raise ValueError("Indices and updates specified for empty output shape") + + if indices_shape.ndims is not None and shape is not None: + outer_dims = len(indices_shape) - 1 + ixdim = indices_shape[-1].value or 0 + + if not indices_shape[:outer_dims].is_compatible_with( + updates_shape[:outer_dims]): + raise ValueError("The outer %d dimensions of indices.shape=%s must " \ + "match the outer %d dimensions of updates.shape=%s" % ( + outer_dims, indices_shape, outer_dims, + updates_shape)) + if output_shape.ndims is not None: + if not output_shape[ixdim:].is_compatible_with(updates_shape[ + outer_dims:]): + raise ValueError("The inner %d dimensions of output.shape=%s must " \ + "match the inner %d dimensions of updates.shape=%s" % ( + len(output_shape)-ixdim, output_shape, + len(updates_shape)-outer_dims, updates_shape)) + + return [output_shape] + return [None] diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 9267b8ef2e..847d1b99c8 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -67,6 +67,11 @@ from tensorflow.python.ops.state_ops import scatter_div from tensorflow.python.ops.state_ops import scatter_mul from tensorflow.python.ops.state_ops import scatter_sub from tensorflow.python.ops.state_ops import scatter_update +from tensorflow.python.ops.state_ops import scatter_nd_add +from tensorflow.python.ops.state_ops import scatter_nd_sub +from tensorflow.python.ops.state_ops import scatter_nd_mul +from tensorflow.python.ops.state_ops import scatter_nd_div +from tensorflow.python.ops.state_ops import scatter_nd_update from tensorflow.python.ops.string_ops import * from tensorflow.python.ops.template import * from tensorflow.python.ops.tensor_array_ops import * diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py index 871ce780c5..314f9f0c1a 100644 --- a/tensorflow/python/ops/state_grad.py +++ b/tensorflow/python/ops/state_grad.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.ops import state_ops + # TODO(b/31222613): These ops may be differentiable, and there may be # latent bugs here. @@ -43,3 +43,14 @@ ops.NotDifferentiable("ScatterMul") ops.NotDifferentiable("ScatterDiv") + + +ops.NotDifferentiable("ScatterNdUpdate") + +ops.NotDifferentiable("ScatterNdAdd") + +ops.NotDifferentiable("ScatterNdSub") + +ops.NotDifferentiable("ScatterNdMul") + +ops.NotDifferentiable("ScatterNdDiv") diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 636acc3e2a..eede28cdc9 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -95,6 +95,11 @@ automatically by the optimizers in most cases. @@scatter_sub @@scatter_mul @@scatter_div +@@scatter_nd_update +@@scatter_nd_add +@@scatter_nd_sub +@@scatter_nd_mul +@@scatter_nd_div @@sparse_mask @@IndexedSlices @@ -209,3 +214,34 @@ ops.RegisterShape("ScatterDiv")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterMul")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterSub")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterUpdate")(common_shapes.call_cpp_shape_fn) + + +@ops.RegisterShape("ScatterNdAdd") +@ops.RegisterShape("ScatterNdSub") +@ops.RegisterShape("ScatterNdMul") +@ops.RegisterShape("ScatterNdDiv") +@ops.RegisterShape("ScatterNdUpdate") +def scatter_nd_update_shape(op): + """Shape function for the ScatterNd update ops.""" + ref_shape = op.inputs[0].get_shape() + indices_shape = op.inputs[1].get_shape() + updates_shape = op.inputs[2].get_shape() + + if indices_shape.ndims is not None and ref_shape.ndims is not None: + outer_dims = len(indices_shape) - 1 + ixdim = indices_shape[-1].value or 0 + + if not indices_shape[:outer_dims].is_compatible_with( + updates_shape[:outer_dims]): + raise ValueError("The outer %d dimensions of indices.shape=%s must " \ + "match the outer %d dimensions of updates.shape=%s" % ( + outer_dims, indices_shape, outer_dims, + updates_shape)) + + if not ref_shape[ixdim:].is_compatible_with(updates_shape[outer_dims:]): + raise ValueError("The inner %d dimensions of ref.shape=%s must match " \ + "the inner %d dimensions of updates.shape=%s" % ( + len(ref_shape)-ixdim, ref_shape, + len(updates_shape)-outer_dims, updates_shape)) + + return [ref_shape] |