author     | 2017-06-27 15:55:10 -0700
committer  | 2017-06-27 16:06:46 -0700
commit     | 63ac86ae19933bc9386233cb899d653f388f1bef (patch)
tree       | 0a73c888050d780e1d5d2014746c9a84d0b795f4 /tensorflow
parent     | 494f5c0c95d16577844cc430b59e611e88750c29 (diff)
Fix bugs in ScatterNd and add ScatterNdNonAliasingAdd.
tf.scatter_nd_non_aliasing_add acts like tf.scatter_nd_add but operates on
non-ref objects (i.e., Tensors, not Variables). As a result it has a gradient
with respect to the primary input as well as the updates. It avoids making
extra copies of the input where possible by forwarding the input buffer to the
output.
PiperOrigin-RevId: 160339328
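For reference, a minimal usage sketch of the new op (TF 1.x graph mode; it assumes the public endpoint tf.scatter_nd_non_aliasing_add and mirrors the example in the op documentation added below):

```
# Minimal sketch of tf.scatter_nd_non_aliasing_add (assumes TF 1.x graph mode).
# Values mirror the example in the ScatterNdNonAliasingAdd op documentation.
import tensorflow as tf

input_ = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])

# Unlike tf.scatter_nd_add, the first argument is a plain Tensor (not a
# ref/Variable) and the result is a new Tensor; the input buffer is only
# reused in place when the runtime can forward it safely.
output = tf.scatter_nd_non_aliasing_add(input_, indices, updates)

with tf.Session() as sess:
  print(sess.run(output))  # => [ 1 13  3 14 14  6  7 20]
```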
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/cc/gradients/array_grad.cc                 |  11
-rw-r--r--  tensorflow/cc/gradients/array_grad_test.cc            |  22
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc         |  62
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h          |   3
-rw-r--r--  tensorflow/core/kernels/BUILD                         |  22
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc              | 149
-rw-r--r--  tensorflow/core/ops/array_ops.cc                      |  56
-rw-r--r--  tensorflow/core/ops/state_ops.cc                      |  67
-rw-r--r--  tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 243
-rw-r--r--  tensorflow/python/ops/array_grad.py                   |   7
10 files changed, 431 insertions, 211 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 37f07e71a0..e69db10580 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -247,6 +247,17 @@ Status ScatterNdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); +Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto indices = op.input(1); + grad_outputs->push_back(Identity(scope, grad_inputs[0])); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); + Status PadGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 5798b5b509..1777e18145 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -233,6 +233,28 @@ TEST_F(ArrayGradTest, ScatterNdGrad_SliceIndexing) { RunTest(updates, updates_shape, y, y_shape); } +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SimpleIndexing) { + TensorShape updates_shape({4}); + TensorShape input_shape({8}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{4}, {3}, {1}, {7}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SliceIndexing) { + TensorShape updates_shape({2, 4, 4}); + TensorShape input_shape({4, 4, 4}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{0}, {2}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + TEST_F(ArrayGradTest, PadGrad) { TensorShape x_shape({2, 3}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 035bceb640..dc2db3d395 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -922,5 +922,67 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, return Status::OK(); } +Status ScatterNdUpdateShape(InferenceContext* c) { + ShapeHandle input_shape = c->input(0); + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); + + if (c->Value(c->NumElements(input_shape)) == 0 && + (c->Value(c->NumElements(indices_shape)) > 0 || + c->Value(c->NumElements(updates_shape)) > 0)) { + return errors::InvalidArgument( + "Indices and updates specified for empty output shape"); + } + + if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { + const int64 num_outer_dims = c->Rank(indices_shape) - 1; + const DimensionHandle index_size = c->Dim(indices_shape, 
-1); + + // We can only do more validation if the last dimension of indices + // is a known value. + if (c->ValueKnown(index_size)) { + const int64 ix = c->Value(index_size); + ShapeHandle unused; + ShapeHandle prefix_indices; + TF_RETURN_IF_ERROR( + c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); + ShapeHandle prefix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); + + Status s = c->Merge(prefix_indices, prefix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The outer ", num_outer_dims, + " dimensions of indices.shape=", c->DebugString(indices_shape), + " must match the outer ", num_outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); + } + + ShapeHandle input_suffix; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); + ShapeHandle suffix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); + s = c->Merge(input_suffix, suffix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The inner ", c->Rank(input_shape) - ix, + " dimensions of input.shape=", c->DebugString(input_shape), + " must match the inner ", c->Rank(updates_shape) - num_outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); + } + } + } + + c->set_output(0, input_shape); + return Status::OK(); +} + } // namespace shape_inference + } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index dc99e48adb..aafd9eff27 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -207,6 +207,9 @@ Status RandomShape(shape_inference::InferenceContext* c); Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape); +// Shape function for ScatterNd update/add/sub/... operations. +Status ScatterNdUpdateShape(InferenceContext* c); + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8b7c269a11..41995a2ff5 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3475,8 +3475,26 @@ tf_kernel_library( tf_kernel_library( name = "scatter_nd_op", - prefix = "scatter_nd_op", - deps = STATE_DEPS, + srcs = [ + "scatter_nd_op.cc", + "scatter_nd_op_cpu_impl_0.cc", + "scatter_nd_op_cpu_impl_1.cc", + "scatter_nd_op_cpu_impl_2.cc", + "scatter_nd_op_cpu_impl_3.cc", + "scatter_nd_op_cpu_impl_4.cc", + "scatter_nd_op_cpu_impl_5.cc", + ], + hdrs = [ + "dense_update_ops.h", + "scatter_nd_op.h", + "scatter_nd_op_cpu_impl.h", + ], + gpu_srcs = [ + "dense_update_ops.h", + "scatter_nd_op.h", + "scatter_nd_op_gpu.cu.cc", + ], + deps = STATE_DEPS + [":dense_update_ops"], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 48565d8cb9..5f96dfa0f3 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -16,12 +16,17 @@ limitations under the License. // See docs in ../ops/state_ops.cc. 
#define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + #include "tensorflow/core/kernels/scatter_nd_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/dense_update_ops.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -29,7 +34,7 @@ limitations under the License. #ifdef TENSORFLOW_USE_SYCL #include "tensorflow/core/common_runtime/sycl/sycl_util.h" -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace tensorflow { @@ -37,7 +42,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape[:batch_dim] + // params_shape[slice_dim:] @@ -91,11 +96,13 @@ static void PrepareAndValidateInputs(OpKernelContext* c, errors::InvalidArgument("Output must be at least 1-D, ", "got shape: ", params_shape.DebugString())); - OP_REQUIRES(c, - params_shape.num_elements() >= 0 || - (indices.NumElements() == 0 && updates.NumElements() == 0), - errors::InvalidArgument( - "Indices and updates specified for empty output", " shape")); + OP_REQUIRES( + c, + params_shape.num_elements() > 0 || + (indices.NumElements() == 0 && updates.NumElements() == 0), + errors::InvalidArgument( + "Indices and updates specified for empty output. indices shape: ", + indices.shape().DebugString())); OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0), errors::InvalidArgument( @@ -147,9 +154,9 @@ static void PrepareAndValidateInputs(OpKernelContext* c, template <typename Device, typename Index> class IndexFlattener { -public: - inline typename TTypes<Index, 2>::ConstTensor - operator()(OpKernelContext*, const Tensor& indices) { + public: + inline typename TTypes<Index, 2>::ConstTensor operator()( + OpKernelContext*, const Tensor& indices) { return indices.flat_inner_dims<Index>(); } }; @@ -157,12 +164,12 @@ public: #ifdef TENSORFLOW_USE_SYCL template <typename Index> class IndexFlattener<SYCLDevice, Index> { -public: + public: IndexFlattener() { indices_host_ = nullptr; } ~IndexFlattener() { delete[] indices_host_; } - inline typename TTypes<Index, 2>::ConstTensor - operator()(OpKernelContext* c, const Tensor& indices) { + inline typename TTypes<Index, 2>::ConstTensor operator()( + OpKernelContext* c, const Tensor& indices) { size_t num_indices = indices.NumElements(); indices_host_ = new Index[num_indices]; auto device = c->eigen_sycl_device(); @@ -170,11 +177,11 @@ public: auto src_ptr = GetBase(&indices); device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr), size); - return typename TTypes<Index, 2>::ConstTensor(indices_host_, - indices.shape().AsEigenDSizes<2>()); + return typename TTypes<Index, 2>::ConstTensor( + indices_host_, indices.shape().AsEigenDSizes<2>()); } -private: + private: Index* indices_host_; }; #endif @@ -213,6 +220,9 @@ class ScatterNdOp : public OpKernel { Tensor* out = nullptr; OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); + + if (shape.num_elements() == 0) return; + functor::SetZeroFunctor<Device, T> fill; fill(c->eigen_device<Device>(), out->flat<T>()); auto output_matrix = out->template shaped<T, 
2>( @@ -271,12 +281,19 @@ class ScatterNdUpdateOp : public OpKernel { const DataType dt = DataTypeToEnum<T>::v(); const DataType dt_ref = DataTypeToEnum<T>::ref(); const DataType index_t = DataTypeToEnum<Index>::v(); - OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref})); - OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); + if (IsRefType(c->input_type(0))) { + OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref})); + OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); + } else { + OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt})); + use_exclusive_lock_ = false; + } } void Compute(OpKernelContext* c) override { if (use_exclusive_lock_) { + // If we're here, it means the input type is a ref. + DCHECK(IsRefType(c->input_dtype(0))); // Hold mutex while we apply updates mutex_lock l(*c->input_ref_mutex(0)); DoCompute(c); @@ -289,20 +306,41 @@ class ScatterNdUpdateOp : public OpKernel { bool use_exclusive_lock_; void DoCompute(OpKernelContext* c) { - Tensor params = c->mutable_input(0, use_exclusive_lock_); const Tensor& indices = c->input(1); const Tensor& updates = c->input(2); - const TensorShape& params_shape(params.shape()); int64 slice_dim; Index num_updates; Index slice_size; - OP_REQUIRES(c, params.IsInitialized(), - errors::FailedPrecondition("Null ref for params")); + Tensor params; + TensorShape params_shape; + + if (IsRefType(c->input_dtype(0))) { + params = c->mutable_input(0, use_exclusive_lock_); + params_shape = params.shape(); + c->forward_ref_input_to_ref_output(0, 0); + OP_REQUIRES(c, params.IsInitialized(), + errors::FailedPrecondition("Null ref for params")); + } else { + Tensor* params_ptr; + params_shape = c->input(0).shape(); + if (!c->forward_input_to_output_with_shape(0, 0, params_shape, + ¶ms_ptr)) { + // We weren't able to forward the input to output, so just + // allocate a new output tensor and copy the values over. 
+ OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, ¶ms_ptr)); + params = *params_ptr; + functor::DenseUpdate<Device, T, ASSIGN> copy; + const Tensor& input_copy = c->input(0); + copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>()); + } + } + PrepareAndValidateInputs<Index>(c, params_shape, indices, updates, &slice_dim, &num_updates, &slice_size); if (!c->status().ok()) return; + if (params_shape.num_elements() == 0) return; IndexFlattener<Device, Index> index_flattener; auto indices_flat = index_flattener(c, indices); @@ -310,7 +348,6 @@ class ScatterNdUpdateOp : public OpKernel { auto params_matrix = params.template shaped<T, 2>( {params_shape.num_elements() / slice_size, slice_size}); Index bad_i = -1; - c->forward_ref_input_to_ref_output(0, 0); switch (slice_dim) { #define PARAMS_CASE(IXDIM) \ @@ -376,10 +413,12 @@ class ScatterNdUpdateOp : public OpKernel { REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op) -#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \ - REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ - scatter_nd_op::UpdateOp::ADD); \ - REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \ +#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ + scatter_nd_op::UpdateOp::ADD); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \ + scatter_nd_op::UpdateOp::ADD); \ + REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \ scatter_nd_op::UpdateOp::SUB); // TODO(simister): Find a way to reduce amount of templated generated code // to reduce build size, then re-enable these additional operations. @@ -421,9 +460,31 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU); - TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_GPU); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \ + REGISTER_SCATTER_ND_ADD_SUB(type, SYCL); + +#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ + REGISTER_SCATTER_ND_UPDATE(type, SYCL); + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL); +#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL +#undef REGISTER_SCATTER_ND_UPDATE_SYCL +#endif // TENSORFLOW_USE_SYCL + +#undef REGISTER_SCATTER_ND_ADD +#undef REGISTER_SCATTER_ND_ADD_SUB +#undef REGISTER_SCATTER_ND_ADD_SUB_CPU +#undef REGISTER_SCATTER_ND_ADD_SUB_GPU +#undef REGISTER_SCATTER_ND_UPDATE +#undef REGISTER_SCATTER_ND_UPDATE_CPU +#undef REGISTER_SCATTER_ND_UPDATE_GPU +#undef REGISTER_SCATTER_ND_KERNEL +#undef REGISTER_SCATTER_ND_KERNEL_INDEX + // Forward declarations of the functor specializations for GPU. 
namespace functor { #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ @@ -458,31 +519,19 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); #undef DECLARE_GPU_SPECS #undef DECLARE_GPU_SPECS_INDEX #undef DECLARE_GPU_SPECS_INDEX_OP -} // namespace functor - -#endif // GOOGLE_CUDA -#ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \ - REGISTER_SCATTER_ND_ADD_SUB(type, SYCL); +#define REGISTER_GPU_KERNELS(type) \ + template <> \ + void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \ + const GPUDevice& d, typename TTypes<type>::Flat lhs, \ + typename TTypes<type>::ConstFlat rhs); \ + extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; -#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ - REGISTER_SCATTER_ND_UPDATE(type, SYCL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL); -#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL -#undef REGISTER_SCATTER_ND_UPDATE_SYCL -#endif // TENSORFLOW_USE_SYCL +} // namespace functor -#undef REGISTER_SCATTER_ND_ADD -#undef REGISTER_SCATTER_ND_ADD_SUB -#undef REGISTER_SCATTER_ND_ADD_SUB_CPU -#undef REGISTER_SCATTER_ND_ADD_SUB_GPU -#undef REGISTER_SCATTER_ND_UPDATE -#undef REGISTER_SCATTER_ND_UPDATE_CPU -#undef REGISTER_SCATTER_ND_UPDATE_GPU -#undef REGISTER_SCATTER_ND_KERNEL -#undef REGISTER_SCATTER_ND_KERNEL_INDEX +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index f27f8968d3..c2b81615bd 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -4912,6 +4912,62 @@ output: A new tensor with the given shape and updates applied according to the indices. )doc"); +REGISTER_OP("ScatterNdNonAliasingAdd") + .Input("input: T") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::ScatterNdUpdateShape) + .Doc(R"doc( +Applies sparse addition to `input` using individual values or slices +from `updates` according to indices `indices`. The updates are non-aliasing: +`input` is only modified in-place if no other operations will use it. +Otherwise, a copy of `input` is made. This operation has a gradient with +respect to both `input` and `updates`. + +`input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +`indices` must be integer tensor, containing indices into `input`. +It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +The innermost dimension of `indices` (with length `K`) corresponds to +indices into elements (if `K = P`) or `(P-K)`-dimensional slices +(if `K < P`) along the `K`th dimension of `input`. + +`updates` is `Tensor` of rank `Q-1+P-K` with shape: + +``` +[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. +``` + +For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +elements. In Python, that addition would look like this: + + input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + output = tf.scatter_nd_non_aliasing_add(input, indices, updates) + with tf.Session() as sess: + print(sess.run(output)) + +The resulting value `output` would look like this: + + [1, 13, 3, 14, 14, 6, 7, 20] + +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to +slices. 
+ +input: A Tensor. +indices: A Tensor. Must be one of the following types: `int32`, `int64`. + A tensor of indices into `input`. +updates: A Tensor. Must have the same type as ref. A tensor of updated values + to add to `input`. +output: A `Tensor` with the same shape as `input`, containing values of `input` + updated with `updates`. +)doc"); + REGISTER_OP("FakeQuantWithMinMaxArgs") .Attr("min: float = -6.0") .Attr("max: float = 6.0") diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 0890d5fc7c..35f965b6a9 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -472,63 +472,6 @@ use_locking: If True, the operation will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); -namespace { - -Status ScatterNdUpdateShape(InferenceContext* c) { - ShapeHandle ref_shape = c->input(0); - ShapeHandle indices_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); - ShapeHandle updates_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); - - if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { - const int64 outer_dims = c->Rank(indices_shape) - 1; - const DimensionHandle ixdim = c->Dim(indices_shape, -1); - - // We can only do more validation if the last dimension of indices - // is a known value. - if (c->ValueKnown(ixdim)) { - int64 ix = c->Value(ixdim); - ShapeHandle unused; - ShapeHandle prefix_indices; - TF_RETURN_IF_ERROR( - c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); - ShapeHandle prefix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); - - Status s = c->Merge(prefix_indices, prefix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The outer ", outer_dims, " dimensions of indices.shape=", - c->DebugString(indices_shape), "must match the outer ", outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - - ShapeHandle suffix_ref; - TF_RETURN_IF_ERROR(c->Subshape(ref_shape, ix, &suffix_ref)); - ShapeHandle suffix_updates; - TF_RETURN_IF_ERROR( - c->Subshape(updates_shape, outer_dims, &suffix_updates)); - s = c->Merge(suffix_ref, suffix_updates, &unused); - if (!s.ok()) { - return errors::InvalidArgument( - "The inner ", c->Rank(ref_shape) - ix, " dimensions of ref.shape=", - c->DebugString(ref_shape), "must match the inner ", - c->Rank(updates_shape) - outer_dims, - " dimensions of updates.shape=", c->DebugString(updates_shape), - ": ", s.error_message()); - } - } - } - - c->set_output(0, ref_shape); - return Status::OK(); -} - -} // namespace - REGISTER_OP("ScatterNdUpdate") .Input("ref: Ref(T)") .Input("indices: Tindices") @@ -537,7 +480,7 @@ REGISTER_OP("ScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .SetShapeFn(ScatterNdUpdateShape) + .SetShapeFn(shape_inference::ScatterNdUpdateShape) .Doc(R"doc( Applies sparse `updates` to individual values or slices within a given variable according to `indices`. @@ -596,7 +539,7 @@ REGISTER_OP("ScatterNdAdd") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(ScatterNdUpdateShape) + .SetShapeFn(shape_inference::ScatterNdUpdateShape) .Doc(R"doc( Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`. 
@@ -653,7 +596,7 @@ REGISTER_OP("ScatterNdSub") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .SetShapeFn(ScatterNdUpdateShape) + .SetShapeFn(shape_inference::ScatterNdUpdateShape) .Doc(R"doc( Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. @@ -713,7 +656,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want // .Attr("T: numbertype") // .Attr("Tindices: {int32, int64}") // .Attr("use_locking: bool = false") -// .SetShapeFn(ScatterNdUpdateShape) +// .SetShapeFn(shape_inference::ScatterNdUpdateShape) // .Doc( // R"doc(Applies sparse subtraction between `updates` and individual // values or slices within a given variable according to `indices`. @@ -769,7 +712,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want // .Attr("T: numbertype") // .Attr("Tindices: {int32, int64}") // .Attr("use_locking: bool = false") -// .SetShapeFn(ScatterNdUpdateShape) +// .SetShapeFn(shape_inference::ScatterNdUpdateShape) // .Doc( // R"doc(Applies sparse subtraction between `updates` and individual // values or slices within a given variable according to `indices`. diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index 8519d19fe1..ebc5686212 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -87,7 +87,7 @@ def _NumpyDiv(ref, indices, updates): return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u) -class ScatterNdTest(test.TestCase): +class StatefulScatterNdTest(test.TestCase): def _VariableRankTest(self, np_scatter, @@ -261,10 +261,6 @@ class ScatterNdTest(test.TestCase): indices = array_ops.zeros([2, 2, 2], dtypes.int32) updates = array_ops.zeros([2, 2, 2], dtypes.int32) shape = np.array([2, 2, 2]) - self.assertAllEqual( - array_ops.scatter_nd(indices, updates, shape).get_shape().as_list(), - shape) - ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) self.assertAllEqual( state_ops.scatter_nd_update(ref, indices, @@ -274,37 +270,120 @@ class ScatterNdTest(test.TestCase): indices = array_ops.zeros([1, 1, 2], dtypes.int32) updates = array_ops.zeros([1, 1], dtypes.int32) shape = np.array([2, 2]) - scatter = array_ops.scatter_nd(indices, updates, shape) - self.assertAllEqual(scatter.get_shape().as_list(), shape) - expected_result = np.zeros([2, 2], dtype=np.int32) - with self.test_session(): - self.assertAllEqual(expected_result, scatter.eval()) - ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) scatter_update = state_ops.scatter_nd_update(ref, indices, updates) self.assertAllEqual(scatter_update.get_shape().as_list(), shape) + expected_result = np.zeros([2, 2], dtype=np.int32) with self.test_session(): ref.initializer.run() self.assertAllEqual(expected_result, scatter_update.eval()) + def testRank3InvalidShape1(self): + indices = array_ops.zeros([3, 2, 2], dtypes.int32) + updates = array_ops.zeros([2, 2, 2], dtypes.int32) + shape = np.array([2, 2, 2]) + ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) + with self.assertRaisesWithPredicateMatch( + ValueError, "The outer \\d+ dimensions of indices\\.shape="): + state_ops.scatter_nd_update(ref, indices, updates) + + def testRank3InvalidShape2(self): + indices = array_ops.zeros([2, 2, 1], dtypes.int32) + updates = array_ops.zeros([2, 2], dtypes.int32) + shape = np.array([2, 2, 2]) + ref = 
variables.Variable(array_ops.zeros(shape, dtypes.int32)) + with self.assertRaisesWithPredicateMatch( + ValueError, "The inner \\d+ dimensions of input\\.shape="): + state_ops.scatter_nd_update(ref, indices, updates) + + def testConcurrentUpdates(self): + num_updates = 10000 + update_values = np.random.rand(num_updates) + ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64) + indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32) + updates = constant_op.constant(update_values, dtype=dtypes.float64) + + expected_result = np.zeros([2, 2], dtype=np.float64) + expected_result[0, 1] = np.sum(update_values) + + scatter = state_ops.scatter_nd_add(ref, indices, updates) + init = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init) + result = sess.run(scatter) + assert np.allclose(result, expected_result) + + # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU. + def _disabledTestScatterOutOfRangeGpu(self): + if not test.IsBuiltWithCuda(): + return + # TODO(simister): Re-enable once binary size increase due to + # scatter_nd ops is under control. + # tf.scatter_nd_mul, tf.scatter_nd_div, + for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub, + state_ops.scatter_nd_update): + params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) + updates = np.array([-3, -4, -5]).astype(np.float32) + # With GPU, the code ignores indices that are out of range. + # We don't test the implementation; just test there's no failures. + with self.test_session(force_gpu=True): + ref = variables.Variable(params) + ref.initializer.run() + + # Indices all in range, no problem. + indices = np.array([2, 0, 5]) + op(ref, indices, updates).eval() + + # Indicies out of range should not fail. 
+ indices = np.array([-1, 0, 5]) + op(ref, indices, updates).eval() + indices = np.array([2, 0, 6]) + op(ref, indices, updates).eval() + + +class ScatterNdTest(test.TestCase): + non_aliasing_add_test = False + + def scatter_nd(self, indices, updates, shape, input_=None): + del input_ # input_ is not used in scatter_nd + return array_ops.scatter_nd(indices, updates, shape) + + def testRank3ValidShape(self): + indices = array_ops.zeros([2, 2, 2], dtypes.int32) + updates = array_ops.zeros([2, 2, 2], dtypes.int32) + shape = np.array([2, 2, 2]) + self.assertAllEqual( + self.scatter_nd(indices, updates, shape).get_shape().as_list(), shape) + + def testExtraIndicesDimensions(self): + indices = array_ops.zeros([1, 1, 2], dtypes.int32) + updates = array_ops.zeros([1, 1], dtypes.int32) + shape = np.array([2, 2]) + scatter = self.scatter_nd(indices, updates, shape) + self.assertAllEqual(scatter.get_shape().as_list(), shape) + expected_result = np.zeros([2, 2], dtype=np.int32) + with self.test_session(): + self.assertAllEqual(expected_result, scatter.eval()) + def testUndefinedIndicesShape(self): indices = array_ops.placeholder(dtypes.int32, shape=None) updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) shape = constant_op.constant([2, 2, 2], dtypes.int32) - array_ops.scatter_nd(indices, updates, shape) + self.scatter_nd(indices, updates, shape) def testUndefinedUpdatesShape(self): indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) updates = array_ops.placeholder(dtypes.int32, shape=None) shape = constant_op.constant([2, 2, 2], dtypes.int32) - array_ops.scatter_nd(indices, updates, shape) + self.scatter_nd(indices, updates, shape) def testUndefinedOutputShape(self): indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) shape = array_ops.placeholder(dtypes.int32, shape=[None]) - array_ops.scatter_nd(indices, updates, shape) + self.scatter_nd(indices, updates, shape) def testEmptyOutputShape1(self): indices = array_ops.zeros([2, 2, 2], dtypes.int32) @@ -313,7 +392,7 @@ class ScatterNdTest(test.TestCase): with self.assertRaisesWithPredicateMatch( ValueError, "Indices and updates specified for empty output shape"): - array_ops.scatter_nd(indices, updates, shape) + self.scatter_nd(indices, updates, shape) def testEmptyOutputShape2(self): indices = array_ops.placeholder(dtypes.int32, shape=None) @@ -321,18 +400,18 @@ class ScatterNdTest(test.TestCase): shape = constant_op.constant([0, 3, 2], dtypes.int32) with self.test_session(): - array_ops.scatter_nd(indices, updates, shape).eval(feed_dict={ - indices: np.zeros( - [2, 2, 2], dtype=np.int32), - updates: np.zeros( - [2, 2, 2], dtype=np.int32) - }) + with self.assertRaisesOpError( + "Indices and updates specified for empty output"): + self.scatter_nd(indices, updates, shape).eval(feed_dict={ + indices: np.zeros([2, 2, 2], dtype=np.int32), + updates: np.zeros([2, 2, 2], dtype=np.int32) + }) def testEmptyOutputShape3(self): indices = array_ops.zeros([0], dtypes.int32) updates = array_ops.zeros([0], dtypes.int32) shape = constant_op.constant([0], dtypes.int32) - scatter = array_ops.scatter_nd(indices, updates, shape) + scatter = self.scatter_nd(indices, updates, shape) with self.test_session(): self.assertEqual(scatter.eval().size, 0) @@ -343,49 +422,49 @@ class ScatterNdTest(test.TestCase): shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( ValueError, "The outer \\d+ dimensions of indices\\.shape="): - array_ops.scatter_nd(indices, updates, 
shape) - - ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) - with self.assertRaisesWithPredicateMatch( - ValueError, "The outer \\d+ dimensions of indices\\.shape="): - state_ops.scatter_nd_update(ref, indices, updates) + self.scatter_nd(indices, updates, shape) def testRank3InvalidShape2(self): indices = array_ops.zeros([2, 2, 1], dtypes.int32) updates = array_ops.zeros([2, 2], dtypes.int32) shape = np.array([2, 2, 2]) with self.assertRaisesWithPredicateMatch( - ValueError, "The inner \\d+ dimensions of output\\.shape="): - array_ops.scatter_nd(indices, updates, shape) - - ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) - with self.assertRaisesWithPredicateMatch( - ValueError, "The inner \\d+ dimensions of ref\\.shape="): - state_ops.scatter_nd_update(ref, indices, updates) + ValueError, "The inner \\d+ dimensions of (input|output)\\.shape="): + self.scatter_nd(indices, updates, shape) def testGradientsRank2ElementUpdate(self): indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32) updates = constant_op.constant([1, 4], dtype=dtypes.float64) shape = constant_op.constant([2, 2], dtype=dtypes.int32) - outputs = array_ops.scatter_nd(indices, updates, shape) + input_ = array_ops.zeros(shape, dtype=dtypes.float64) + outputs = self.scatter_nd(indices, updates, shape, input_) grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64) - grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0] - expected_grads = np.array([1, 4], dtype=np.float64) + updates_grad, input_grad = gradients_impl.gradients( + [outputs], [updates, input_], [grad_vals]) + expected_updates_grad = np.array([1, 4], dtype=np.float64) + expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64) with self.test_session(): - self.assertAllEqual(expected_grads, grads.eval()) + self.assertAllEqual(expected_updates_grad, updates_grad.eval()) + if self.non_aliasing_add_test: + self.assertAllEqual(expected_input_grad, input_grad.eval()) def testGradientsRank2SliceUpdate(self): indices = constant_op.constant([[1], [0]], dtype=dtypes.int32) updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64) shape = constant_op.constant([2, 2], dtype=dtypes.int32) - outputs = array_ops.scatter_nd(indices, updates, shape) + input_ = array_ops.zeros(shape, dtype=dtypes.float64) + outputs = self.scatter_nd(indices, updates, shape, input_) grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64) - grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0] - expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64) + updates_grad, input_grad = gradients_impl.gradients( + [outputs], [updates, input_], [grad_vals]) + expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64) + expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64) with self.test_session(): - self.assertAllEqual(expected_grads, grads.eval()) + self.assertAllEqual(expected_updates_grad, updates_grad.eval()) + if self.non_aliasing_add_test: + self.assertAllEqual(expected_input_grad, input_grad.eval()) def testGradientsRank3SliceUpdate(self): indices = constant_op.constant( @@ -393,67 +472,28 @@ class ScatterNdTest(test.TestCase): updates = constant_op.constant( [[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=dtypes.float64) shape = constant_op.constant([2, 2, 2], dtype=dtypes.int32) - outputs = array_ops.scatter_nd(indices, updates, shape) + input_ = array_ops.zeros(shape, dtype=dtypes.float64) + outputs = self.scatter_nd(indices, updates, shape, 
input_) grad_vals = constant_op.constant( [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64) - grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0] - expected_grads = np.array( + updates_grad, input_grad = gradients_impl.gradients( + [outputs], [updates, input_], [grad_vals]) + expected_updates_grad = np.array( [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64) + expected_input_grad = np.array( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64) with self.test_session(): - self.assertAllEqual(expected_grads, grads.eval()) - - def testConcurrentUpdates(self): - num_updates = 10000 - update_values = np.random.rand(num_updates) - ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64) - indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32) - updates = constant_op.constant(update_values, dtype=dtypes.float64) - - expected_result = np.zeros([2, 2], dtype=np.float64) - expected_result[0, 1] = np.sum(update_values) - - scatter = state_ops.scatter_nd_add(ref, indices, updates) - init = variables.global_variables_initializer() - - with session.Session() as sess: - sess.run(init) - result = sess.run(scatter) - assert np.allclose(result, expected_result) - - # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU. - def _disabledTestScatterOutOfRangeGpu(self): - if not test.IsBuiltWithCuda(): - return - # TODO(simister): Re-enable once binary size increase due to - # scatter_nd ops is under control. - # tf.scatter_nd_mul, tf.scatter_nd_div, - for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub, - state_ops.scatter_nd_update): - params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) - updates = np.array([-3, -4, -5]).astype(np.float32) - # With GPU, the code ignores indices that are out of range. - # We don't test the implementation; just test there's no failures. - with self.test_session(force_gpu=True): - ref = variables.Variable(params) - ref.initializer.run() - - # Indices all in range, no problem. - indices = np.array([2, 0, 5]) - op(ref, indices, updates).eval() - - # Indicies out of range should not fail. 
- indices = np.array([-1, 0, 5]) - op(ref, indices, updates).eval() - indices = np.array([2, 0, 6]) - op(ref, indices, updates).eval() + self.assertAllEqual(expected_updates_grad, updates_grad.eval()) + if self.non_aliasing_add_test: + self.assertAllEqual(expected_input_grad, input_grad.eval()) def testScatterNdRepatedIndicesAdd(self): indices = array_ops.zeros([100000, 1], dtypes.int32) values = np.random.randn(100000) shape = [1] with self.test_session(): - val = array_ops.scatter_nd(indices, values, shape).eval() + val = self.scatter_nd(indices, values, shape).eval() self.assertAllClose([np.sum(values)], val) def testSmokeScatterNdBatch2DSliceDim2(self): @@ -461,28 +501,37 @@ class ScatterNdTest(test.TestCase): indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32) values = array_ops.zeros([3, 5, 7]) shape = [4, 6, 7] - array_ops.scatter_nd(indices, values, shape).eval() + self.scatter_nd(indices, values, shape).eval() def testSmokeScatterNdBatch1DSliceDim2(self): with self.test_session(): indices = array_ops.zeros([0, 2], dtype=dtypes.int32) values = array_ops.zeros([0, 7]) shape = [4, 6, 7] - array_ops.scatter_nd(indices, values, shape).eval() + self.scatter_nd(indices, values, shape).eval() def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self): with self.test_session(): indices = array_ops.zeros([1, 3], dtype=dtypes.int32) values = array_ops.zeros([1, 6, 7, 8, 9]) shape = [3, 4, 5, 6, 7, 8, 9] - array_ops.scatter_nd(indices, values, shape).eval() + self.scatter_nd(indices, values, shape).eval() def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self): with self.test_session(): indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32) values = array_ops.zeros([1, 2, 6, 7, 8, 9]) shape = [3, 4, 5, 6, 7, 8, 9] - array_ops.scatter_nd(indices, values, shape).eval() + self.scatter_nd(indices, values, shape).eval() + + +class ScatterNdNonAliasingAddTest(ScatterNdTest): + non_aliasing_add_test = True + + def scatter_nd(self, indices, updates, shape, input_=None): + input_ = (input_ if input_ is not None else array_ops.zeros( + shape, dtype=updates.dtype)) + return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates) if __name__ == "__main__": diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 5c6d309e6c..dcae3e0c7b 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -670,3 +670,10 @@ def _ScatterNdGrad(op, grad): indices = op.inputs[0] updates_grad = array_ops.gather_nd(grad, indices) return [None, updates_grad, None] + + +@ops.RegisterGradient("ScatterNdNonAliasingAdd") +def _ScatterNdNonAliasingAddGrad(op, grad): + indices = op.inputs[1] + updates_grad = array_ops.gather_nd(grad, indices) + return [grad, None, updates_grad] |
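To illustrate the Python gradient registered in array_grad.py above (the incoming gradient flows unchanged to `input` and is gathered at `indices` for `updates`, while `indices` gets no gradient), here is a small hedged sketch, again assuming the tf.scatter_nd_non_aliasing_add endpoint and TF 1.x tf.gradients:

```
# Hedged sketch of _ScatterNdNonAliasingAddGrad: for output gradient g,
# d(input) = g and d(updates) = gather_nd(g, indices); indices get no gradient.
import tensorflow as tf

input_ = tf.constant([[0.0, 0.0], [0.0, 0.0]])
indices = tf.constant([[1], [0]])
updates = tf.constant([[3.0, 4.0], [1.0, 2.0]])
output = tf.scatter_nd_non_aliasing_add(input_, indices, updates)

grad_vals = tf.constant([[10.0, 20.0], [30.0, 40.0]])
input_grad, updates_grad = tf.gradients([output], [input_, updates], [grad_vals])

with tf.Session() as sess:
  d_input, d_updates = sess.run([input_grad, updates_grad])
  # d_input   == grad_vals                      -> [[10, 20], [30, 40]]
  # d_updates == grad_vals gathered at indices  -> [[30, 40], [10, 20]]
```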