aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-05-14 14:58:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 15:01:28 -0700
commit0775f684c51b6b2f24d58c116cc2073d53659e3c (patch)
tree4b3be72ca16c4ea88b5d2e35121447050790cb73
parent5de9b8463ee214a02aa71815c837b49c6ea2f93c (diff)
Do shape validation in ScatterNd kernel, not just the shape inference function.
Fixes #18648 PiperOrigin-RevId: 196572262
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc47
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py12
2 files changed, 57 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 0caa7bd317..8ef6e77398 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -62,14 +62,57 @@ class ScatterNdOp : public OpKernel {
const Tensor& updates = c->input(1);
const Tensor& shape_input = c->input(2);
- OP_REQUIRES(c, shape_input.dims() == 1,
- errors::InvalidArgument("Shape must be a vector"));
+ OP_REQUIRES(c, indices.shape().dims() >= 1,
+ errors::InvalidArgument(
+ "Indices shape must have rank at least one. Found:",
+ indices.shape().DebugString()));
+ OP_REQUIRES(c, updates.shape().dims() >= 1,
+ errors::InvalidArgument(
+ "Updates shape must have rank at least one. Found:",
+ updates.shape().DebugString()));
auto vec = shape_input.flat<Index>();
TensorShape shape;
OP_REQUIRES_OK(c,
TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
+ OP_REQUIRES(
+ c,
+ (shape.num_elements() > 0 || (indices.shape().num_elements() == 0 &&
+ updates.shape().num_elements() == 0)),
+ errors::InvalidArgument(
+ "Indices and updates specified for empty output shape"));
+
+ const int64 outer_dims = indices.shape().dims() - 1;
+
+ for (int i = 0; i < outer_dims; ++i) {
+ OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
+ errors::InvalidArgument(
+ "Outer dimensions of indices and update must match. "
+ "Indices shape: ",
+ indices.shape().DebugString(),
+ ", updates shape:", updates.shape().DebugString()));
+ }
+
+ const int64 ix = indices.shape().dim_size(outer_dims);
+ OP_REQUIRES(
+ c, updates.shape().dims() - outer_dims == shape.dims() - ix,
+ errors::InvalidArgument("Inner dimensions of output shape must match "
+ "inner dimensions of updates shape. Output: ",
+ shape.DebugString(),
+ " updates: ", updates.shape().DebugString()));
+ for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
+ OP_REQUIRES(
+ c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
+ errors::InvalidArgument(
+ "The inner ", shape.dims() - ix,
+ " dimensions of output.shape=", shape.DebugString(),
+ " must match the inner ", updates.shape().dims() - outer_dims,
+ " dimensions of updates.shape=", updates.shape().DebugString()));
+ }
+ OP_REQUIRES(c, shape_input.dims() == 1,
+ errors::InvalidArgument("Shape must be a vector"));
+
Tensor out;
OP_REQUIRES_OK(
c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index b7477a768a..79fe927b8a 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -23,8 +23,11 @@ import functools
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -364,6 +367,15 @@ class ScatterNdTest(test.TestCase):
del input_ # input_ is not used in scatter_nd
return array_ops.scatter_nd(indices, updates, shape)
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidShape(self):
+ # TODO(apassos) figure out how to unify these errors
+ with self.assertRaises(errors.InvalidArgumentError
+ if context.executing_eagerly() else ValueError):
+ array_ops.scatter_nd(indices=[0], # this should be indices=[[0]]
+ updates=[0.0],
+ shape=[1])
+
def testString(self):
indices = constant_op.constant([[4], [3], [1], [7]],
dtype=dtypes.int32)