diff options
author | Alexandre Passos <apassos@google.com> | 2018-05-14 14:58:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 15:01:28 -0700 |
commit | 0775f684c51b6b2f24d58c116cc2073d53659e3c (patch) | |
tree | 4b3be72ca16c4ea88b5d2e35121447050790cb73 | |
parent | 5de9b8463ee214a02aa71815c837b49c6ea2f93c (diff) |
Do shape validation in ScatterNd kernel, not just the shape inference function.
Fixes #18648
PiperOrigin-RevId: 196572262
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 47 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 12 |
2 files changed, 57 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 0caa7bd317..8ef6e77398 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -62,14 +62,57 @@ class ScatterNdOp : public OpKernel { const Tensor& updates = c->input(1); const Tensor& shape_input = c->input(2); - OP_REQUIRES(c, shape_input.dims() == 1, - errors::InvalidArgument("Shape must be a vector")); + OP_REQUIRES(c, indices.shape().dims() >= 1, + errors::InvalidArgument( + "Indices shape must have rank at least one. Found:", + indices.shape().DebugString())); + OP_REQUIRES(c, updates.shape().dims() >= 1, + errors::InvalidArgument( + "Updates shape must have rank at least one. Found:", + updates.shape().DebugString())); auto vec = shape_input.flat<Index>(); TensorShape shape; OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape)); + OP_REQUIRES( + c, + (shape.num_elements() > 0 || (indices.shape().num_elements() == 0 && + updates.shape().num_elements() == 0)), + errors::InvalidArgument( + "Indices and updates specified for empty output shape")); + + const int64 outer_dims = indices.shape().dims() - 1; + + for (int i = 0; i < outer_dims; ++i) { + OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i), + errors::InvalidArgument( + "Outer dimensions of indices and update must match. " + "Indices shape: ", + indices.shape().DebugString(), + ", updates shape:", updates.shape().DebugString())); + } + + const int64 ix = indices.shape().dim_size(outer_dims); + OP_REQUIRES( + c, updates.shape().dims() - outer_dims == shape.dims() - ix, + errors::InvalidArgument("Inner dimensions of output shape must match " + "inner dimensions of updates shape. Output: ", + shape.DebugString(), + " updates: ", updates.shape().DebugString())); + for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { + OP_REQUIRES( + c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), + errors::InvalidArgument( + "The inner ", shape.dims() - ix, + " dimensions of output.shape=", shape.DebugString(), + " must match the inner ", updates.shape().dims() - outer_dims, + " dimensions of updates.shape=", updates.shape().DebugString())); + } + OP_REQUIRES(c, shape_input.dims() == 1, + errors::InvalidArgument("Shape must be a vector")); + Tensor out; OP_REQUIRES_OK( c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>( diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index b7477a768a..79fe927b8a 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -23,8 +23,11 @@ import functools import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops @@ -364,6 +367,15 @@ class ScatterNdTest(test.TestCase): del input_ # input_ is not used in scatter_nd return array_ops.scatter_nd(indices, updates, shape) + @test_util.run_in_graph_and_eager_modes() + def testInvalidShape(self): + # TODO(apassos) figure out how to unify these errors + with self.assertRaises(errors.InvalidArgumentError + if context.executing_eagerly() else ValueError): + array_ops.scatter_nd(indices=[0], # this should be indices=[[0]] + updates=[0.0], + shape=[1]) + def testString(self): indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) |