author     Eugene Brevdo <ebrevdo@google.com>               2017-06-27 15:55:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-06-27 16:06:46 -0700
commit     63ac86ae19933bc9386233cb899d653f388f1bef (patch)
tree       0a73c888050d780e1d5d2014746c9a84d0b795f4 /tensorflow
parent     494f5c0c95d16577844cc430b59e611e88750c29 (diff)
Fix bugs in ScatterNd and add ScatterNdNonAliasingAdd.
tf.scatter_nd_non_aliasing_add acts similarly to tf.scatter_nd_add but works on
non-ref objects (i.e., Tensors -- not Variables). This means it has a gradient
with respect to the primary input as well as the updates. It does its best to
avoid making extra copies of the input.

PiperOrigin-RevId: 160339328
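For reference, a minimal usage sketch of the new op (TF 1.x Python API; the
values are taken from the op documentation added further down in this diff):

    import tensorflow as tf

    input_ = tf.constant([1., 2., 3., 4., 5., 6., 7., 8.])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9., 10., 11., 12.])
    output = tf.scatter_nd_non_aliasing_add(input_, indices, updates)
    # Unlike the ref-based tf.scatter_nd_add, gradients flow to both inputs:
    input_grad, updates_grad = tf.gradients(output, [input_, updates])
    with tf.Session() as sess:
        print(sess.run(output))  # => [ 1. 13.  3. 14. 14.  6.  7. 20.]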
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/cc/gradients/array_grad.cc                  |  11
-rw-r--r--  tensorflow/cc/gradients/array_grad_test.cc             |  22
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc          |  62
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h           |   3
-rw-r--r--  tensorflow/core/kernels/BUILD                          |  22
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc               | 149
-rw-r--r--  tensorflow/core/ops/array_ops.cc                       |  56
-rw-r--r--  tensorflow/core/ops/state_ops.cc                       |  67
-rw-r--r--  tensorflow/python/kernel_tests/scatter_nd_ops_test.py  | 243
-rw-r--r--  tensorflow/python/ops/array_grad.py                    |   7
10 files changed, 431 insertions(+), 211 deletions(-)
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 37f07e71a0..e69db10580 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -247,6 +247,17 @@ Status ScatterNdGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
+Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto indices = op.input(1);
+ grad_outputs->push_back(Identity(scope, grad_inputs[0]));
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
+
Status PadGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
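In plain terms, the gradient registered above passes the upstream gradient
straight through to `input` and gathers it at `indices` for `updates` (each
update feeds exactly one output slot). A small NumPy check of that logic, with
hypothetical values:

    import numpy as np

    grad = np.array([1., 2., 3., 4.])   # upstream gradient dL/dy
    indices = np.array([[2], [0]])      # where the updates were added
    # GatherNd with length-1 index vectors reduces to fancy indexing:
    updates_grad = grad[indices[:, 0]]
    print(updates_grad)                 # => [3. 1.]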
diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc
index 5798b5b509..1777e18145 100644
--- a/tensorflow/cc/gradients/array_grad_test.cc
+++ b/tensorflow/cc/gradients/array_grad_test.cc
@@ -233,6 +233,28 @@ TEST_F(ArrayGradTest, ScatterNdGrad_SliceIndexing) {
RunTest(updates, updates_shape, y, y_shape);
}
+TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SimpleIndexing) {
+ TensorShape updates_shape({4});
+ TensorShape input_shape({8});
+ auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
+ auto updates =
+ Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape));
+ auto indices = Const(scope_, {{4}, {3}, {1}, {7}});
+ auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates);
+ RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape});
+}
+
+TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SliceIndexing) {
+ TensorShape updates_shape({2, 4, 4});
+ TensorShape input_shape({4, 4, 4});
+ auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
+ auto updates =
+ Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape));
+ auto indices = Const(scope_, {{0}, {2}});
+ auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates);
+ RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape});
+}
+
TEST_F(ArrayGradTest, PadGrad) {
TensorShape x_shape({2, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 035bceb640..dc2db3d395 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -922,5 +922,67 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
return Status::OK();
}
+Status ScatterNdUpdateShape(InferenceContext* c) {
+ ShapeHandle input_shape = c->input(0);
+ ShapeHandle indices_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
+ ShapeHandle updates_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
+
+ if (c->Value(c->NumElements(input_shape)) == 0 &&
+ (c->Value(c->NumElements(indices_shape)) > 0 ||
+ c->Value(c->NumElements(updates_shape)) > 0)) {
+ return errors::InvalidArgument(
+ "Indices and updates specified for empty output shape");
+ }
+
+ if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
+ const int64 num_outer_dims = c->Rank(indices_shape) - 1;
+ const DimensionHandle index_size = c->Dim(indices_shape, -1);
+
+ // We can only do more validation if the last dimension of indices
+ // is a known value.
+ if (c->ValueKnown(index_size)) {
+ const int64 ix = c->Value(index_size);
+ ShapeHandle unused;
+ ShapeHandle prefix_indices;
+ TF_RETURN_IF_ERROR(
+ c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
+ ShapeHandle prefix_updates;
+ TF_RETURN_IF_ERROR(
+ c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));
+
+ Status s = c->Merge(prefix_indices, prefix_updates, &unused);
+ if (!s.ok()) {
+ return errors::InvalidArgument(
+ "The outer ", num_outer_dims,
+ " dimensions of indices.shape=", c->DebugString(indices_shape),
+ " must match the outer ", num_outer_dims,
+ " dimensions of updates.shape=", c->DebugString(updates_shape),
+ ": ", s.error_message());
+ }
+
+ ShapeHandle input_suffix;
+ TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
+ ShapeHandle suffix_updates;
+ TF_RETURN_IF_ERROR(
+ c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
+ s = c->Merge(input_suffix, suffix_updates, &unused);
+ if (!s.ok()) {
+ return errors::InvalidArgument(
+ "The inner ", c->Rank(input_shape) - ix,
+ " dimensions of input.shape=", c->DebugString(input_shape),
+ " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
+ " dimensions of updates.shape=", c->DebugString(updates_shape),
+ ": ", s.error_message());
+ }
+ }
+ }
+
+ c->set_output(0, input_shape);
+ return Status::OK();
+}
+
} // namespace shape_inference
+
} // namespace tensorflow
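The `ScatterNdUpdateShape` function moved here encodes the contract
`indices.shape = outer + [K]` and `updates.shape = outer + input.shape[K:]`.
A hedged Python restatement of that rule, with hypothetical shapes:

    def check_scatter_nd_shapes(input_shape, indices_shape, updates_shape):
        # Mirrors ScatterNdUpdateShape: the outer dims of indices and updates
        # must match, and updates' suffix must match input.shape[K:].
        k = indices_shape[-1]
        outer = indices_shape[:-1]
        assert updates_shape == outer + input_shape[k:]

    check_scatter_nd_shapes((4, 4, 4), (2, 1), (2, 4, 4))  # slice updates, K=1
    check_scatter_nd_shapes((8,), (4, 1), (4,))            # element updates, K=1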
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index dc99e48adb..aafd9eff27 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -207,6 +207,9 @@ Status RandomShape(shape_inference::InferenceContext* c);
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape);
+// Shape function for ScatterNd update/add/sub/... operations.
+Status ScatterNdUpdateShape(InferenceContext* c);
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8b7c269a11..41995a2ff5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3475,8 +3475,26 @@ tf_kernel_library(
tf_kernel_library(
name = "scatter_nd_op",
- prefix = "scatter_nd_op",
- deps = STATE_DEPS,
+ srcs = [
+ "scatter_nd_op.cc",
+ "scatter_nd_op_cpu_impl_0.cc",
+ "scatter_nd_op_cpu_impl_1.cc",
+ "scatter_nd_op_cpu_impl_2.cc",
+ "scatter_nd_op_cpu_impl_3.cc",
+ "scatter_nd_op_cpu_impl_4.cc",
+ "scatter_nd_op_cpu_impl_5.cc",
+ ],
+ hdrs = [
+ "dense_update_ops.h",
+ "scatter_nd_op.h",
+ "scatter_nd_op_cpu_impl.h",
+ ],
+ gpu_srcs = [
+ "dense_update_ops.h",
+ "scatter_nd_op.h",
+ "scatter_nd_op_gpu.cu.cc",
+ ],
+ deps = STATE_DEPS + [":dense_update_ops"],
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 48565d8cb9..5f96dfa0f3 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -16,12 +16,17 @@ limitations under the License.
// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/dense_update_ops.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -29,7 +34,7 @@ limitations under the License.
#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
namespace tensorflow {
@@ -37,7 +42,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
@@ -91,11 +96,13 @@ static void PrepareAndValidateInputs(OpKernelContext* c,
errors::InvalidArgument("Output must be at least 1-D, ",
"got shape: ", params_shape.DebugString()));
- OP_REQUIRES(c,
- params_shape.num_elements() >= 0 ||
- (indices.NumElements() == 0 && updates.NumElements() == 0),
- errors::InvalidArgument(
- "Indices and updates specified for empty output", " shape"));
+ OP_REQUIRES(
+ c,
+ params_shape.num_elements() > 0 ||
+ (indices.NumElements() == 0 && updates.NumElements() == 0),
+ errors::InvalidArgument(
+ "Indices and updates specified for empty output. indices shape: ",
+ indices.shape().DebugString()));
OP_REQUIRES(c, updates.dim_size(0) == indices.dim_size(0),
errors::InvalidArgument(
@@ -147,9 +154,9 @@ static void PrepareAndValidateInputs(OpKernelContext* c,
template <typename Device, typename Index>
class IndexFlattener {
-public:
- inline typename TTypes<Index, 2>::ConstTensor
- operator()(OpKernelContext*, const Tensor& indices) {
+ public:
+ inline typename TTypes<Index, 2>::ConstTensor operator()(
+ OpKernelContext*, const Tensor& indices) {
return indices.flat_inner_dims<Index>();
}
};
@@ -157,12 +164,12 @@ public:
#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
-public:
+ public:
IndexFlattener() { indices_host_ = nullptr; }
~IndexFlattener() { delete[] indices_host_; }
- inline typename TTypes<Index, 2>::ConstTensor
- operator()(OpKernelContext* c, const Tensor& indices) {
+ inline typename TTypes<Index, 2>::ConstTensor operator()(
+ OpKernelContext* c, const Tensor& indices) {
size_t num_indices = indices.NumElements();
indices_host_ = new Index[num_indices];
auto device = c->eigen_sycl_device();
@@ -170,11 +177,11 @@ public:
auto src_ptr = GetBase(&indices);
device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
size);
- return typename TTypes<Index, 2>::ConstTensor(indices_host_,
- indices.shape().AsEigenDSizes<2>());
+ return typename TTypes<Index, 2>::ConstTensor(
+ indices_host_, indices.shape().AsEigenDSizes<2>());
}
-private:
+ private:
Index* indices_host_;
};
#endif
@@ -213,6 +220,9 @@ class ScatterNdOp : public OpKernel {
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));
+
+ if (shape.num_elements() == 0) return;
+
functor::SetZeroFunctor<Device, T> fill;
fill(c->eigen_device<Device>(), out->flat<T>());
auto output_matrix = out->template shaped<T, 2>(
@@ -271,12 +281,19 @@ class ScatterNdUpdateOp : public OpKernel {
const DataType dt = DataTypeToEnum<T>::v();
const DataType dt_ref = DataTypeToEnum<T>::ref();
const DataType index_t = DataTypeToEnum<Index>::v();
- OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
- OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
+ if (IsRefType(c->input_type(0))) {
+ OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
+ OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
+ } else {
+ OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
+ use_exclusive_lock_ = false;
+ }
}
void Compute(OpKernelContext* c) override {
if (use_exclusive_lock_) {
+ // If we're here, it means the input type is a ref.
+ DCHECK(IsRefType(c->input_dtype(0)));
// Hold mutex while we apply updates
mutex_lock l(*c->input_ref_mutex(0));
DoCompute(c);
@@ -289,20 +306,41 @@ class ScatterNdUpdateOp : public OpKernel {
bool use_exclusive_lock_;
void DoCompute(OpKernelContext* c) {
- Tensor params = c->mutable_input(0, use_exclusive_lock_);
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
- const TensorShape& params_shape(params.shape());
int64 slice_dim;
Index num_updates;
Index slice_size;
- OP_REQUIRES(c, params.IsInitialized(),
- errors::FailedPrecondition("Null ref for params"));
+ Tensor params;
+ TensorShape params_shape;
+
+ if (IsRefType(c->input_dtype(0))) {
+ params = c->mutable_input(0, use_exclusive_lock_);
+ params_shape = params.shape();
+ c->forward_ref_input_to_ref_output(0, 0);
+ OP_REQUIRES(c, params.IsInitialized(),
+ errors::FailedPrecondition("Null ref for params"));
+ } else {
+ Tensor* params_ptr;
+ params_shape = c->input(0).shape();
+ if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
+ &params_ptr)) {
+ // We weren't able to forward the input to output, so just
+ // allocate a new output tensor and copy the values over.
+ OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
+ params = *params_ptr;
+ functor::DenseUpdate<Device, T, ASSIGN> copy;
+ const Tensor& input_copy = c->input(0);
+ copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
+ }
+ }
+
PrepareAndValidateInputs<Index>(c, params_shape, indices, updates,
&slice_dim, &num_updates, &slice_size);
if (!c->status().ok()) return;
+ if (params_shape.num_elements() == 0) return;
IndexFlattener<Device, Index> index_flattener;
auto indices_flat = index_flattener(c, indices);
@@ -310,7 +348,6 @@ class ScatterNdUpdateOp : public OpKernel {
auto params_matrix = params.template shaped<T, 2>(
{params_shape.num_elements() / slice_size, slice_size});
Index bad_i = -1;
- c->forward_ref_input_to_ref_output(0, 0);
switch (slice_dim) {
#define PARAMS_CASE(IXDIM) \
@@ -376,10 +413,12 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
-#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
- scatter_nd_op::UpdateOp::ADD); \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
+#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
+ scatter_nd_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
+ scatter_nd_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
scatter_nd_op::UpdateOp::SUB);
// TODO(simister): Find a way to reduce amount of templated generated code
// to reduce build size, then re-enable these additional operations.
@@ -421,9 +460,31 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_GPU);
-
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_GPU);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
+ REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);
+
+#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
+ REGISTER_SCATTER_ND_UPDATE(type, SYCL);
+
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
+#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
+#undef REGISTER_SCATTER_ND_UPDATE_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
+#undef REGISTER_SCATTER_ND_ADD
+#undef REGISTER_SCATTER_ND_ADD_SUB
+#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
+#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
+#undef REGISTER_SCATTER_ND_UPDATE
+#undef REGISTER_SCATTER_ND_UPDATE_CPU
+#undef REGISTER_SCATTER_ND_UPDATE_GPU
+#undef REGISTER_SCATTER_ND_KERNEL
+#undef REGISTER_SCATTER_ND_KERNEL_INDEX
+
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
@@ -458,31 +519,19 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP
-} // namespace functor
-
-#endif // GOOGLE_CUDA
-#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
- REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);
+#define REGISTER_GPU_KERNELS(type) \
+ template <> \
+ void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
+ const GPUDevice& d, typename TTypes<type>::Flat lhs, \
+ typename TTypes<type>::ConstFlat rhs); \
+ extern template struct DenseUpdate<GPUDevice, type, ASSIGN>;
-#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
- REGISTER_SCATTER_ND_UPDATE(type, SYCL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
-#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
-#undef REGISTER_SCATTER_ND_UPDATE_SYCL
-#endif // TENSORFLOW_USE_SYCL
+} // namespace functor
-#undef REGISTER_SCATTER_ND_ADD
-#undef REGISTER_SCATTER_ND_ADD_SUB
-#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
-#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
-#undef REGISTER_SCATTER_ND_UPDATE
-#undef REGISTER_SCATTER_ND_UPDATE_CPU
-#undef REGISTER_SCATTER_ND_UPDATE_GPU
-#undef REGISTER_SCATTER_ND_KERNEL
-#undef REGISTER_SCATTER_ND_KERNEL_INDEX
+#endif // GOOGLE_CUDA
} // namespace tensorflow
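The kernel changes above let the same `ScatterNdUpdateOp` serve the new
non-ref op: the input buffer is forwarded to the output when the kernel can
claim it, and copied into a fresh output otherwise. A hedged NumPy sketch of
the resulting semantics (pseudologic only, not the actual C++ API):

    import numpy as np

    def non_aliasing_add(input_arr, indices, updates, can_forward=False):
        # Mutate in place only when the buffer is exclusively owned;
        # otherwise copy first so other consumers still see the old values.
        out = input_arr if can_forward else input_arr.copy()
        for idx, u in zip(indices, updates):
            out[tuple(idx)] += u
        return out

    x = np.array([1., 2., 3., 4., 5., 6., 7., 8.])
    y = non_aliasing_add(x, [[4], [3], [1], [7]], [9., 10., 11., 12.])
    print(y)  # => [ 1. 13.  3. 14. 14.  6.  7. 20.]
    print(x)  # unchanged: the copy preserved the input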
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index f27f8968d3..c2b81615bd 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -4912,6 +4912,62 @@ output: A new tensor with the given shape and updates applied according
to the indices.
)doc");
+REGISTER_OP("ScatterNdNonAliasingAdd")
+ .Input("input: T")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
+ .Doc(R"doc(
+Applies sparse addition to `input` using individual values or slices
+from `updates` according to `indices`. The updates are non-aliasing:
+`input` is only modified in-place if no other operations will use it.
+Otherwise, a copy of `input` is made. This operation has a gradient with
+respect to both `input` and `updates`.
+
+`input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `input`.
+It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or `(P-K)`-dimensional slices
+(if `K < P`) along the `K`th dimension of `input`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].
+```
+
+For example, say we want to add 4 scattered elements to a rank-1 tensor with
+8 elements. In Python, that addition would look like this:
+
+ input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ output = tf.scatter_nd_non_aliasing_add(input, indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(output))
+
+The resulting value `output` would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to
+slices.
+
+input: A Tensor.
+indices: A Tensor. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into `input`.
+updates: A Tensor. Must have the same type as `input`. A tensor of updated
+ values to add to `input`.
+output: A `Tensor` with the same shape as `input`, containing values of `input`
+ updated with `updates`.
+)doc");
+
REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 0890d5fc7c..35f965b6a9 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -472,63 +472,6 @@ use_locking: If True, the operation will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
-namespace {
-
-Status ScatterNdUpdateShape(InferenceContext* c) {
- ShapeHandle ref_shape = c->input(0);
- ShapeHandle indices_shape;
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
- ShapeHandle updates_shape;
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
-
- if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
- const int64 outer_dims = c->Rank(indices_shape) - 1;
- const DimensionHandle ixdim = c->Dim(indices_shape, -1);
-
- // We can only do more validation if the last dimension of indices
- // is a known value.
- if (c->ValueKnown(ixdim)) {
- int64 ix = c->Value(ixdim);
- ShapeHandle unused;
- ShapeHandle prefix_indices;
- TF_RETURN_IF_ERROR(
- c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
- ShapeHandle prefix_updates;
- TF_RETURN_IF_ERROR(
- c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
-
- Status s = c->Merge(prefix_indices, prefix_updates, &unused);
- if (!s.ok()) {
- return errors::InvalidArgument(
- "The outer ", outer_dims, " dimensions of indices.shape=",
- c->DebugString(indices_shape), "must match the outer ", outer_dims,
- " dimensions of updates.shape=", c->DebugString(updates_shape),
- ": ", s.error_message());
- }
-
- ShapeHandle suffix_ref;
- TF_RETURN_IF_ERROR(c->Subshape(ref_shape, ix, &suffix_ref));
- ShapeHandle suffix_updates;
- TF_RETURN_IF_ERROR(
- c->Subshape(updates_shape, outer_dims, &suffix_updates));
- s = c->Merge(suffix_ref, suffix_updates, &unused);
- if (!s.ok()) {
- return errors::InvalidArgument(
- "The inner ", c->Rank(ref_shape) - ix, " dimensions of ref.shape=",
- c->DebugString(ref_shape), "must match the inner ",
- c->Rank(updates_shape) - outer_dims,
- " dimensions of updates.shape=", c->DebugString(updates_shape),
- ": ", s.error_message());
- }
- }
- }
-
- c->set_output(0, ref_shape);
- return Status::OK();
-}
-
-} // namespace
-
REGISTER_OP("ScatterNdUpdate")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
@@ -537,7 +480,7 @@ REGISTER_OP("ScatterNdUpdate")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
- .SetShapeFn(ScatterNdUpdateShape)
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse `updates` to individual values or slices within a given
variable according to `indices`.
@@ -596,7 +539,7 @@ REGISTER_OP("ScatterNdAdd")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
- .SetShapeFn(ScatterNdUpdateShape)
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse addition between `updates` and individual values or slices
within a given variable according to `indices`.
@@ -653,7 +596,7 @@ REGISTER_OP("ScatterNdSub")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
- .SetShapeFn(ScatterNdUpdateShape)
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse subtraction between `updates` and individual values or slices
within a given variable according to `indices`.
@@ -713,7 +656,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want
// .Attr("T: numbertype")
// .Attr("Tindices: {int32, int64}")
// .Attr("use_locking: bool = false")
-// .SetShapeFn(ScatterNdUpdateShape)
+// .SetShapeFn(shape_inference::ScatterNdUpdateShape)
// .Doc(
// R"doc(Applies sparse subtraction between `updates` and individual
// values or slices within a given variable according to `indices`.
@@ -769,7 +712,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want
// .Attr("T: numbertype")
// .Attr("Tindices: {int32, int64}")
// .Attr("use_locking: bool = false")
-// .SetShapeFn(ScatterNdUpdateShape)
+// .SetShapeFn(shape_inference::ScatterNdUpdateShape)
// .Doc(
// R"doc(Applies sparse subtraction between `updates` and individual
// values or slices within a given variable according to `indices`.
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 8519d19fe1..ebc5686212 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -87,7 +87,7 @@ def _NumpyDiv(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
-class ScatterNdTest(test.TestCase):
+class StatefulScatterNdTest(test.TestCase):
def _VariableRankTest(self,
np_scatter,
@@ -261,10 +261,6 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
- self.assertAllEqual(
- array_ops.scatter_nd(indices, updates, shape).get_shape().as_list(),
- shape)
-
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
self.assertAllEqual(
state_ops.scatter_nd_update(ref, indices,
@@ -274,37 +270,120 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([1, 1, 2], dtypes.int32)
updates = array_ops.zeros([1, 1], dtypes.int32)
shape = np.array([2, 2])
- scatter = array_ops.scatter_nd(indices, updates, shape)
- self.assertAllEqual(scatter.get_shape().as_list(), shape)
- expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
- self.assertAllEqual(expected_result, scatter.eval())
-
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
+ expected_result = np.zeros([2, 2], dtype=np.int32)
with self.test_session():
ref.initializer.run()
self.assertAllEqual(expected_result, scatter_update.eval())
+ def testRank3InvalidShape1(self):
+ indices = array_ops.zeros([3, 2, 2], dtypes.int32)
+ updates = array_ops.zeros([2, 2, 2], dtypes.int32)
+ shape = np.array([2, 2, 2])
+ ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The outer \\d+ dimensions of indices\\.shape="):
+ state_ops.scatter_nd_update(ref, indices, updates)
+
+ def testRank3InvalidShape2(self):
+ indices = array_ops.zeros([2, 2, 1], dtypes.int32)
+ updates = array_ops.zeros([2, 2], dtypes.int32)
+ shape = np.array([2, 2, 2])
+ ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, "The inner \\d+ dimensions of input\\.shape="):
+ state_ops.scatter_nd_update(ref, indices, updates)
+
+ def testConcurrentUpdates(self):
+ num_updates = 10000
+ update_values = np.random.rand(num_updates)
+ ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
+ indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
+ updates = constant_op.constant(update_values, dtype=dtypes.float64)
+
+ expected_result = np.zeros([2, 2], dtype=np.float64)
+ expected_result[0, 1] = np.sum(update_values)
+
+ scatter = state_ops.scatter_nd_add(ref, indices, updates)
+ init = variables.global_variables_initializer()
+
+ with session.Session() as sess:
+ sess.run(init)
+ result = sess.run(scatter)
+ assert np.allclose(result, expected_result)
+
+ # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
+ def _disabledTestScatterOutOfRangeGpu(self):
+ if not test.IsBuiltWithCuda():
+ return
+ # TODO(simister): Re-enable once binary size increase due to
+ # scatter_nd ops is under control.
+ # tf.scatter_nd_mul, tf.scatter_nd_div,
+ for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
+ state_ops.scatter_nd_update):
+ params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
+ updates = np.array([-3, -4, -5]).astype(np.float32)
+ # With GPU, the code ignores indices that are out of range.
+      # We don't test the implementation; just test that there are no failures.
+ with self.test_session(force_gpu=True):
+ ref = variables.Variable(params)
+ ref.initializer.run()
+
+ # Indices all in range, no problem.
+ indices = np.array([2, 0, 5])
+ op(ref, indices, updates).eval()
+
+      # Indices out of range should not fail.
+ indices = np.array([-1, 0, 5])
+ op(ref, indices, updates).eval()
+ indices = np.array([2, 0, 6])
+ op(ref, indices, updates).eval()
+
+
+class ScatterNdTest(test.TestCase):
+ non_aliasing_add_test = False
+
+ def scatter_nd(self, indices, updates, shape, input_=None):
+ del input_ # input_ is not used in scatter_nd
+ return array_ops.scatter_nd(indices, updates, shape)
+
+ def testRank3ValidShape(self):
+ indices = array_ops.zeros([2, 2, 2], dtypes.int32)
+ updates = array_ops.zeros([2, 2, 2], dtypes.int32)
+ shape = np.array([2, 2, 2])
+ self.assertAllEqual(
+ self.scatter_nd(indices, updates, shape).get_shape().as_list(), shape)
+
+ def testExtraIndicesDimensions(self):
+ indices = array_ops.zeros([1, 1, 2], dtypes.int32)
+ updates = array_ops.zeros([1, 1], dtypes.int32)
+ shape = np.array([2, 2])
+ scatter = self.scatter_nd(indices, updates, shape)
+ self.assertAllEqual(scatter.get_shape().as_list(), shape)
+ expected_result = np.zeros([2, 2], dtype=np.int32)
+ with self.test_session():
+ self.assertAllEqual(expected_result, scatter.eval())
+
def testUndefinedIndicesShape(self):
indices = array_ops.placeholder(dtypes.int32, shape=None)
updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
shape = constant_op.constant([2, 2, 2], dtypes.int32)
- array_ops.scatter_nd(indices, updates, shape)
+ self.scatter_nd(indices, updates, shape)
def testUndefinedUpdatesShape(self):
indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([2, 2, 2], dtypes.int32)
- array_ops.scatter_nd(indices, updates, shape)
+ self.scatter_nd(indices, updates, shape)
def testUndefinedOutputShape(self):
indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
shape = array_ops.placeholder(dtypes.int32, shape=[None])
- array_ops.scatter_nd(indices, updates, shape)
+ self.scatter_nd(indices, updates, shape)
def testEmptyOutputShape1(self):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
@@ -313,7 +392,7 @@ class ScatterNdTest(test.TestCase):
with self.assertRaisesWithPredicateMatch(
ValueError, "Indices and updates specified for empty output shape"):
- array_ops.scatter_nd(indices, updates, shape)
+ self.scatter_nd(indices, updates, shape)
def testEmptyOutputShape2(self):
indices = array_ops.placeholder(dtypes.int32, shape=None)
@@ -321,18 +400,18 @@ class ScatterNdTest(test.TestCase):
shape = constant_op.constant([0, 3, 2], dtypes.int32)
with self.test_session():
- array_ops.scatter_nd(indices, updates, shape).eval(feed_dict={
- indices: np.zeros(
- [2, 2, 2], dtype=np.int32),
- updates: np.zeros(
- [2, 2, 2], dtype=np.int32)
- })
+ with self.assertRaisesOpError(
+ "Indices and updates specified for empty output"):
+ self.scatter_nd(indices, updates, shape).eval(feed_dict={
+ indices: np.zeros([2, 2, 2], dtype=np.int32),
+ updates: np.zeros([2, 2, 2], dtype=np.int32)
+ })
def testEmptyOutputShape3(self):
indices = array_ops.zeros([0], dtypes.int32)
updates = array_ops.zeros([0], dtypes.int32)
shape = constant_op.constant([0], dtypes.int32)
- scatter = array_ops.scatter_nd(indices, updates, shape)
+ scatter = self.scatter_nd(indices, updates, shape)
with self.test_session():
self.assertEqual(scatter.eval().size, 0)
@@ -343,49 +422,49 @@ class ScatterNdTest(test.TestCase):
shape = np.array([2, 2, 2])
with self.assertRaisesWithPredicateMatch(
ValueError, "The outer \\d+ dimensions of indices\\.shape="):
- array_ops.scatter_nd(indices, updates, shape)
-
- ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
- with self.assertRaisesWithPredicateMatch(
- ValueError, "The outer \\d+ dimensions of indices\\.shape="):
- state_ops.scatter_nd_update(ref, indices, updates)
+ self.scatter_nd(indices, updates, shape)
def testRank3InvalidShape2(self):
indices = array_ops.zeros([2, 2, 1], dtypes.int32)
updates = array_ops.zeros([2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
with self.assertRaisesWithPredicateMatch(
- ValueError, "The inner \\d+ dimensions of output\\.shape="):
- array_ops.scatter_nd(indices, updates, shape)
-
- ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
- with self.assertRaisesWithPredicateMatch(
- ValueError, "The inner \\d+ dimensions of ref\\.shape="):
- state_ops.scatter_nd_update(ref, indices, updates)
+ ValueError, "The inner \\d+ dimensions of (input|output)\\.shape="):
+ self.scatter_nd(indices, updates, shape)
def testGradientsRank2ElementUpdate(self):
indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
updates = constant_op.constant([1, 4], dtype=dtypes.float64)
shape = constant_op.constant([2, 2], dtype=dtypes.int32)
- outputs = array_ops.scatter_nd(indices, updates, shape)
+ input_ = array_ops.zeros(shape, dtype=dtypes.float64)
+ outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
- grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0]
- expected_grads = np.array([1, 4], dtype=np.float64)
+ updates_grad, input_grad = gradients_impl.gradients(
+ [outputs], [updates, input_], [grad_vals])
+ expected_updates_grad = np.array([1, 4], dtype=np.float64)
+ expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
with self.test_session():
- self.assertAllEqual(expected_grads, grads.eval())
+ self.assertAllEqual(expected_updates_grad, updates_grad.eval())
+ if self.non_aliasing_add_test:
+ self.assertAllEqual(expected_input_grad, input_grad.eval())
def testGradientsRank2SliceUpdate(self):
indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
shape = constant_op.constant([2, 2], dtype=dtypes.int32)
- outputs = array_ops.scatter_nd(indices, updates, shape)
+ input_ = array_ops.zeros(shape, dtype=dtypes.float64)
+ outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
- grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0]
- expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64)
+ updates_grad, input_grad = gradients_impl.gradients(
+ [outputs], [updates, input_], [grad_vals])
+ expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
+ expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
with self.test_session():
- self.assertAllEqual(expected_grads, grads.eval())
+ self.assertAllEqual(expected_updates_grad, updates_grad.eval())
+ if self.non_aliasing_add_test:
+ self.assertAllEqual(expected_input_grad, input_grad.eval())
def testGradientsRank3SliceUpdate(self):
indices = constant_op.constant(
@@ -393,67 +472,28 @@ class ScatterNdTest(test.TestCase):
updates = constant_op.constant(
[[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=dtypes.float64)
shape = constant_op.constant([2, 2, 2], dtype=dtypes.int32)
- outputs = array_ops.scatter_nd(indices, updates, shape)
+ input_ = array_ops.zeros(shape, dtype=dtypes.float64)
+ outputs = self.scatter_nd(indices, updates, shape, input_)
grad_vals = constant_op.constant(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
- grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0]
- expected_grads = np.array(
+ updates_grad, input_grad = gradients_impl.gradients(
+ [outputs], [updates, input_], [grad_vals])
+ expected_updates_grad = np.array(
[[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
+ expected_input_grad = np.array(
+ [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
with self.test_session():
- self.assertAllEqual(expected_grads, grads.eval())
-
- def testConcurrentUpdates(self):
- num_updates = 10000
- update_values = np.random.rand(num_updates)
- ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
- indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
- updates = constant_op.constant(update_values, dtype=dtypes.float64)
-
- expected_result = np.zeros([2, 2], dtype=np.float64)
- expected_result[0, 1] = np.sum(update_values)
-
- scatter = state_ops.scatter_nd_add(ref, indices, updates)
- init = variables.global_variables_initializer()
-
- with session.Session() as sess:
- sess.run(init)
- result = sess.run(scatter)
- assert np.allclose(result, expected_result)
-
- # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
- def _disabledTestScatterOutOfRangeGpu(self):
- if not test.IsBuiltWithCuda():
- return
- # TODO(simister): Re-enable once binary size increase due to
- # scatter_nd ops is under control.
- # tf.scatter_nd_mul, tf.scatter_nd_div,
- for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
- state_ops.scatter_nd_update):
- params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
- updates = np.array([-3, -4, -5]).astype(np.float32)
- # With GPU, the code ignores indices that are out of range.
- # We don't test the implementation; just test there's no failures.
- with self.test_session(force_gpu=True):
- ref = variables.Variable(params)
- ref.initializer.run()
-
- # Indices all in range, no problem.
- indices = np.array([2, 0, 5])
- op(ref, indices, updates).eval()
-
- # Indicies out of range should not fail.
- indices = np.array([-1, 0, 5])
- op(ref, indices, updates).eval()
- indices = np.array([2, 0, 6])
- op(ref, indices, updates).eval()
+ self.assertAllEqual(expected_updates_grad, updates_grad.eval())
+ if self.non_aliasing_add_test:
+ self.assertAllEqual(expected_input_grad, input_grad.eval())
def testScatterNdRepatedIndicesAdd(self):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
with self.test_session():
- val = array_ops.scatter_nd(indices, values, shape).eval()
+ val = self.scatter_nd(indices, values, shape).eval()
self.assertAllClose([np.sum(values)], val)
def testSmokeScatterNdBatch2DSliceDim2(self):
@@ -461,28 +501,37 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
values = array_ops.zeros([3, 5, 7])
shape = [4, 6, 7]
- array_ops.scatter_nd(indices, values, shape).eval()
+ self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim2(self):
with self.test_session():
indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
values = array_ops.zeros([0, 7])
shape = [4, 6, 7]
- array_ops.scatter_nd(indices, values, shape).eval()
+ self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
with self.test_session():
indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
- array_ops.scatter_nd(indices, values, shape).eval()
+ self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
with self.test_session():
indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 2, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
- array_ops.scatter_nd(indices, values, shape).eval()
+ self.scatter_nd(indices, values, shape).eval()
+
+
+class ScatterNdNonAliasingAddTest(ScatterNdTest):
+ non_aliasing_add_test = True
+
+ def scatter_nd(self, indices, updates, shape, input_=None):
+ input_ = (input_ if input_ is not None else array_ops.zeros(
+ shape, dtype=updates.dtype))
+ return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates)
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 5c6d309e6c..dcae3e0c7b 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -670,3 +670,10 @@ def _ScatterNdGrad(op, grad):
indices = op.inputs[0]
updates_grad = array_ops.gather_nd(grad, indices)
return [None, updates_grad, None]
+
+
+@ops.RegisterGradient("ScatterNdNonAliasingAdd")
+def _ScatterNdNonAliasingAddGrad(op, grad):
+ indices = op.inputs[1]
+ updates_grad = array_ops.gather_nd(grad, indices)
+ return [grad, None, updates_grad]
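A minimal sketch exercising the gradient registered above (assumes the TF 1.x
API; values match testGradientsRank2ElementUpdate in the test file):

    import tensorflow as tf

    x = tf.zeros([2, 2], dtype=tf.float64)
    indices = tf.constant([[0, 0], [1, 1]])
    updates = tf.constant([1., 4.], dtype=tf.float64)
    y = tf.scatter_nd_non_aliasing_add(x, indices, updates)
    grad_ys = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float64)
    dx, du = tf.gradients([y], [x, updates], [grad_ys])
    # dx == grad_ys (passed through); du == gather_nd(grad_ys, indices) == [1., 4.]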