diff options
author | 2018-05-26 01:38:33 +0200 | |
---|---|---|
committer | 2018-05-25 16:38:33 -0700 | |
commit | 83116bafebb500fa963809599b7f2583367c92d6 (patch) | |
tree | 1fdaff5404a07c83daccaae9c88f68851fce2a36 /tensorflow/python/kernel_tests/gather_op_test.py | |
parent | 38926b8a0fa89bef74085be0e321c13e739795d4 (diff) |
Fix of issue #13164 (Merges #13382) (#16368)
* tf.gather int64 GPU, tf.gather_nd int32/int64 GPU, tf.scatter_nd int32 GPU
* Fix tf.gather test
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index a2fcd751df..033fa95935 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -27,7 +27,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.platform import test -_TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128) +_TEST_TYPES = (dtypes.int64, dtypes.float32, + dtypes.complex64, dtypes.complex128) class GatherTest(test.TestCase): @@ -122,6 +123,9 @@ class GatherTest(test.TestCase): gather, [tf_params, tf_indices, tf_axis], gather_grad) self.assertEqual(indices_grad, None) self.assertEqual(axis_grad, None) + if dtype.is_integer: + self.assertEqual(params_grad, None) + continue # For axis 0, we are able to create an efficient IndexedSlices for # the gradient. if axis == 0: @@ -177,7 +181,19 @@ class GatherTest(test.TestCase): gather_t = array_ops.gather(params, indices, axis=axis) self.assertEqual(None, gather_t.shape) - def testBadIndices(self): + def testBadIndicesCPU(self): + with self.test_session(use_gpu=False): + params = [[0, 1, 2], [3, 4, 5]] + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): + array_ops.gather(params, [[7]], axis=0).eval() + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): + array_ops.gather(params, [[7]], axis=1).eval() + + def _disabledTestBadIndicesGPU(self): + # TODO disabled due to different behavior on GPU and CPU + # On GPU the bad indices do not raise error but fetch 0 values + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): params = [[0, 1, 2], [3, 4, 5]] with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): |