aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_nd_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_nd_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py32
1 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 91ebe8de99..58e2a8ac2a 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -197,7 +197,21 @@ class GatherNdTest(test.TestCase):
self.assertEqual(None, shape.ndims)
self.assertEqual(None, shape[0].value)
- def testBadIndices(self):
+ def testBadIndicesCPU(self):
+ with self.test_session(use_gpu=False):
+ params = [0, 1, 2]
+ indices = [[[0], [7]]] # Make this one higher rank
+ gather_nd = array_ops.gather_nd(params, indices)
+ with self.assertRaisesOpError(
+ r"flat indices\[1, :\] = \[7\] does not index into param "
+ r"\(shape: \[3\]\)"):
+ gather_nd.eval()
+
+ def _disabledTestBadIndicesGPU(self):
+ # TODO disabled due to different behavior on GPU and CPU
+ # On GPU the bad indices do not raise error but fetch 0 values
+ if not test.is_gpu_available():
+ return
with self.test_session(use_gpu=True):
params = [0, 1, 2]
indices = [[[0], [7]]] # Make this one higher rank
@@ -207,7 +221,21 @@ class GatherNdTest(test.TestCase):
r"\(shape: \[3\]\)"):
gather_nd.eval()
- def testBadIndicesWithSlices(self):
+ def testBadIndicesWithSlicesCPU(self):
+ with self.test_session(use_gpu=False):
+ params = [[0, 1, 2]]
+ indices = [[[0], [0], [1]]] # Make this one higher rank
+ gather_nd = array_ops.gather_nd(params, indices)
+ with self.assertRaisesOpError(
+ r"flat indices\[2, :\] = \[1\] does not index into param "
+ r"\(shape: \[1,3\]\)"):
+ gather_nd.eval()
+
+ def _disabledTestBadIndicesWithSlicesGPU(self):
+ # TODO disabled due to different behavior on GPU and CPU
+ # On GPU the bad indices do not raise error but fetch 0 values
+ if not test.is_gpu_available():
+ return
with self.test_session(use_gpu=True):
params = [[0, 1, 2]]
indices = [[[0], [0], [1]]] # Make this one higher rank