diff options
author | Geoffrey Irving <irving@naml.us> | 2018-07-11 13:05:14 -0700 |
---|---|---|
committer | Geoffrey Irving <irving@naml.us> | 2018-07-18 12:07:02 -0700 |
commit | a10592ef7741d858466a980239fc95e65d7c66b6 (patch) | |
tree | c4d2d72d2e6eaf7477e0861618d24507dc36008c | |
parent | f5a830421f287208a51bd04a94842913eb1fc0d2 (diff) |
Improve error messages for gather_nd and scatter_nd
Use SliceDebugString to produce nice error messages using
multidimensional indexes.
-rw-r--r-- | tensorflow/core/kernels/gather_nd_op.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/gather_nd_op_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 4 |
5 files changed, 15 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index 4e53291b7f..e50b7fe3bf 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -188,12 +188,13 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params, // bad_i will only return >= 0 on CPUs right now. if (bad_i >= 0) { + auto shape = indices.shape(); + shape.RemoveLastDims(1); return errors::InvalidArgument( - "flat indices[", bad_i, ", :] = [", + "indices", SliceDebugString(shape, bad_i), " = [", str_util::Join( gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "), - "] does not index into param (shape: ", params.shape().DebugString(), - ")."); + "] does not index into param shape ", params.shape().DebugString()); } } return Status::OK(); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index e1fc2ea128..5f300fb64d 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -537,11 +537,13 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, } } if (bad_i >= 0) { + auto slice_shape = indices.shape(); + slice_shape.RemoveLastDims(1); return errors::InvalidArgument( - "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [", + "indices", SliceDebugString(slice_shape, bad_i), " = [", str_util::Join( gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "), - "] does not index into ", shape.DebugString()); + "] does not index into shape ", shape.DebugString()); } return Status::OK(); } diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index c134a8dd5b..95ecc69c95 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -185,7 +185,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) { {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); Status s = RunOpKernel(); EXPECT_TRUE(str_util::StrContains( - s.ToString(), "Invalid indices: [2,0] = [99] does not index into [5,3]")) + s.ToString(), "indices[2] = [99] does not index into shape [5,3]")) << s; } diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index 58e2a8ac2a..c0b419e1d1 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -203,8 +203,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [7]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[1, :\] = \[7\] does not index into param " - r"\(shape: \[3\]\)"): + r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"): gather_nd.eval() def _disabledTestBadIndicesGPU(self): @@ -217,8 +216,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [7]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[1, :\] = \[7\] does not index into param " - r"\(shape: \[3\]\)"): + r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"): gather_nd.eval() def testBadIndicesWithSlicesCPU(self): @@ -227,8 +225,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [0], [1]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[2, :\] = \[1\] does not index into param " - r"\(shape: \[1,3\]\)"): + r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"): gather_nd.eval() def _disabledTestBadIndicesWithSlicesGPU(self): @@ -241,8 +238,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [0], [1]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[2, :\] = \[1\] does not index into param " - r"\(shape: \[1,3\]\)"): + r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"): gather_nd.eval() def testGradientsRank2Elements(self): diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index f9b9c77bbf..c31499e52d 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -268,12 +268,12 @@ class StatefulScatterNdTest(test.TestCase): # Test some out of range errors. indices = np.array([[-1], [0], [5]]) with self.assertRaisesOpError( - r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"): + r"indices\[0\] = \[-1\] does not index into shape \[6\]"): op(ref, indices, updates).eval() indices = np.array([[2], [0], [6]]) with self.assertRaisesOpError( - r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"): + r"indices\[2\] = \[6\] does not index into shape \[6\]"): op(ref, indices, updates).eval() def testRank3ValidShape(self): |