aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <irving@naml.us>2018-07-11 13:05:14 -0700
committerGravatar Geoffrey Irving <irving@naml.us>2018-07-18 12:07:02 -0700
commita10592ef7741d858466a980239fc95e65d7c66b6 (patch)
treec4d2d72d2e6eaf7477e0861618d24507dc36008c
parentf5a830421f287208a51bd04a94842913eb1fc0d2 (diff)
Improve error messages for gather_nd and scatter_nd
Use SliceDebugString to produce nice error messages using multidimensional indexes.
-rw-r--r--tensorflow/core/kernels/gather_nd_op.cc7
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc6
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_test.cc2
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py4
5 files changed, 15 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc
index 4e53291b7f..e50b7fe3bf 100644
--- a/tensorflow/core/kernels/gather_nd_op.cc
+++ b/tensorflow/core/kernels/gather_nd_op.cc
@@ -188,12 +188,13 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
// bad_i will only return >= 0 on CPUs right now.
if (bad_i >= 0) {
+ auto shape = indices.shape();
+ shape.RemoveLastDims(1);
return errors::InvalidArgument(
- "flat indices[", bad_i, ", :] = [",
+ "indices", SliceDebugString(shape, bad_i), " = [",
str_util::Join(
gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "),
- "] does not index into param (shape: ", params.shape().DebugString(),
- ").");
+ "] does not index into param shape ", params.shape().DebugString());
}
}
return Status::OK();
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e1fc2ea128..5f300fb64d 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -537,11 +537,13 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
}
}
if (bad_i >= 0) {
+ auto slice_shape = indices.shape();
+ slice_shape.RemoveLastDims(1);
return errors::InvalidArgument(
- "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [",
+ "indices", SliceDebugString(slice_shape, bad_i), " = [",
str_util::Join(
gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
- "] does not index into ", shape.DebugString());
+ "] does not index into shape ", shape.DebugString());
}
return Status::OK();
}
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index c134a8dd5b..95ecc69c95 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -185,7 +185,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
EXPECT_TRUE(str_util::StrContains(
- s.ToString(), "Invalid indices: [2,0] = [99] does not index into [5,3]"))
+ s.ToString(), "indices[2] = [99] does not index into shape [5,3]"))
<< s;
}
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 58e2a8ac2a..c0b419e1d1 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -203,8 +203,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[1, :\] = \[7\] does not index into param "
- r"\(shape: \[3\]\)"):
+ r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"):
gather_nd.eval()
def _disabledTestBadIndicesGPU(self):
@@ -217,8 +216,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[1, :\] = \[7\] does not index into param "
- r"\(shape: \[3\]\)"):
+ r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"):
gather_nd.eval()
def testBadIndicesWithSlicesCPU(self):
@@ -227,8 +225,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[2, :\] = \[1\] does not index into param "
- r"\(shape: \[1,3\]\)"):
+ r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
gather_nd.eval()
def _disabledTestBadIndicesWithSlicesGPU(self):
@@ -241,8 +238,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[2, :\] = \[1\] does not index into param "
- r"\(shape: \[1,3\]\)"):
+ r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
gather_nd.eval()
def testGradientsRank2Elements(self):
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f9b9c77bbf..c31499e52d 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -268,12 +268,12 @@ class StatefulScatterNdTest(test.TestCase):
# Test some out of range errors.
indices = np.array([[-1], [0], [5]])
with self.assertRaisesOpError(
- r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"):
+ r"indices\[0\] = \[-1\] does not index into shape \[6\]"):
op(ref, indices, updates).eval()
indices = np.array([[2], [0], [6]])
with self.assertRaisesOpError(
- r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"):
+ r"indices\[2\] = \[6\] does not index into shape \[6\]"):
op(ref, indices, updates).eval()
def testRank3ValidShape(self):