aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_nd_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_nd_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py12
1 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 58e2a8ac2a..c0b419e1d1 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -203,8 +203,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[1, :\] = \[7\] does not index into param "
- r"\(shape: \[3\]\)"):
+ r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"):
gather_nd.eval()
def _disabledTestBadIndicesGPU(self):
@@ -217,8 +216,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[1, :\] = \[7\] does not index into param "
- r"\(shape: \[3\]\)"):
+ r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"):
gather_nd.eval()
def testBadIndicesWithSlicesCPU(self):
@@ -227,8 +225,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[2, :\] = \[1\] does not index into param "
- r"\(shape: \[1,3\]\)"):
+ r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
gather_nd.eval()
def _disabledTestBadIndicesWithSlicesGPU(self):
@@ -241,8 +238,7 @@ class GatherNdTest(test.TestCase):
indices = [[[0], [0], [1]]] # Make this one higher rank
gather_nd = array_ops.gather_nd(params, indices)
with self.assertRaisesOpError(
- r"flat indices\[2, :\] = \[1\] does not index into param "
- r"\(shape: \[1,3\]\)"):
+ r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"):
gather_nd.eval()
def testGradientsRank2Elements(self):