diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_nd_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_nd_op_test.py | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index 91ebe8de99..58e2a8ac2a 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -197,7 +197,21 @@ class GatherNdTest(test.TestCase): self.assertEqual(None, shape.ndims) self.assertEqual(None, shape[0].value) - def testBadIndices(self): + def testBadIndicesCPU(self): + with self.test_session(use_gpu=False): + params = [0, 1, 2] + indices = [[[0], [7]]] # Make this one higher rank + gather_nd = array_ops.gather_nd(params, indices) + with self.assertRaisesOpError( + r"flat indices\[1, :\] = \[7\] does not index into param " + r"\(shape: \[3\]\)"): + gather_nd.eval() + + def _disabledTestBadIndicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [0, 1, 2] indices = [[[0], [7]]] # Make this one higher rank @@ -207,7 +221,21 @@ class GatherNdTest(test.TestCase): r"\(shape: \[3\]\)"): gather_nd.eval() - def testBadIndicesWithSlices(self): + def testBadIndicesWithSlicesCPU(self): + with self.test_session(use_gpu=False): + params = [[0, 1, 2]] + indices = [[[0], [0], [1]]] # Make this one higher rank + gather_nd = array_ops.gather_nd(params, indices) + with self.assertRaisesOpError( + r"flat indices\[2, :\] = \[1\] does not index into param " + r"\(shape: \[1,3\]\)"): + gather_nd.eval() + + def _disabledTestBadIndicesWithSlicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [[0, 1, 2]] indices = [[[0], [0], [1]]] # Make this one higher rank |