aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_op_test.py
diff options
context:
space:
mode:
authorGravatar Robin Richtsfeld <robin.richtsfeld@gmail.com>2018-05-26 01:38:33 +0200
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-25 16:38:33 -0700
commit83116bafebb500fa963809599b7f2583367c92d6 (patch)
tree1fdaff5404a07c83daccaae9c88f68851fce2a36 /tensorflow/python/kernel_tests/gather_op_test.py
parent38926b8a0fa89bef74085be0e321c13e739795d4 (diff)
Fix of issue #13164 (Merges #13382) (#16368)
* tf.gather int64 GPU, tf.gather_nd int32/int64 GPU, tf.scatter_nd int32 GPU * Fix tf.gather test
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index a2fcd751df..033fa95935 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -27,7 +27,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
-_TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128)
+_TEST_TYPES = (dtypes.int64, dtypes.float32,
+ dtypes.complex64, dtypes.complex128)
class GatherTest(test.TestCase):
@@ -122,6 +123,9 @@ class GatherTest(test.TestCase):
gather, [tf_params, tf_indices, tf_axis], gather_grad)
self.assertEqual(indices_grad, None)
self.assertEqual(axis_grad, None)
+ if dtype.is_integer:
+ self.assertEqual(params_grad, None)
+ continue
# For axis 0, we are able to create an efficient IndexedSlices for
# the gradient.
if axis == 0:
@@ -177,7 +181,19 @@ class GatherTest(test.TestCase):
gather_t = array_ops.gather(params, indices, axis=axis)
self.assertEqual(None, gather_t.shape)
- def testBadIndices(self):
+ def testBadIndicesCPU(self):
+ with self.test_session(use_gpu=False):
+ params = [[0, 1, 2], [3, 4, 5]]
+ with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
+ array_ops.gather(params, [[7]], axis=0).eval()
+ with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
+ array_ops.gather(params, [[7]], axis=1).eval()
+
+ def _disabledTestBadIndicesGPU(self):
+ # TODO disabled due to different behavior on GPU and CPU
+ # On GPU the bad indices do not raise error but fetch 0 values
+ if not test.is_gpu_available():
+ return
with self.test_session(use_gpu=True):
params = [[0, 1, 2], [3, 4, 5]]
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):