diff options
author | 2016-08-17 12:16:54 -0800 | |
---|---|---|
committer | 2016-08-17 13:33:16 -0700 | |
commit | 7817445b6677f66bba04ec7f3c0836cd3b719011 (patch) | |
tree | 00efbc4d3770a827c2e29b53309302bd47fedaa9 /tensorflow/python/kernel_tests/gather_op_test.py | |
parent | 0fcaae1ebd5943b0622f9953dbe0fb22f8eafbfc (diff) |
Fix gather for nonempty indices, empty slices
tf.gather([0], [[]]) now works. This required
1. Avoiding an empty GPU kernel launch.
2. Avoiding the use of -1 reshapes in the gradient. This code was also
simplified a bit using the fact that array_ops.shape does its own
checking for fully defined shapes.
Change: 130554783
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 52 |
1 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index c3ef9e491d..272f28d437 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase): def testHigherRank(self): np.random.seed(1) - shape = (4, 3, 2) - params = np.random.randn(*shape) - indices = np.random.randint(shape[0], size=15).reshape(3, 5) - with self.test_session(use_gpu=self.use_gpu): - tf_params = tf.constant(params) - tf_indices = tf.constant(indices) - gather = tf.gather(tf_params, tf_indices) - self.assertAllEqual(params[indices], gather.eval()) - self.assertEqual(indices.shape + params.shape[1:], gather.get_shape()) - # Test gradients - gather_grad = np.random.randn(*gather.get_shape().as_list()) - params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices], - gather_grad) - self.assertEqual(indices_grad, None) - self.assertEqual(type(params_grad), tf.IndexedSlices) - params_grad = tf.convert_to_tensor(params_grad) - correct_params_grad = np.zeros(shape) - for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])): - correct_params_grad[i] += g - self.assertAllClose(correct_params_grad, params_grad.eval()) + # We check that scalar and empty shapes work as well + for shape in (7, 0), (4, 3, 2): + for indices_shape in (), (0,), (3, 0), (3, 5): + params = np.random.randn(*shape) + indices = np.random.randint(shape[0], size=indices_shape) + with self.test_session(use_gpu=self.use_gpu): + tf_params = tf.constant(params) + tf_indices = tf.constant(indices) + gather = tf.gather(tf_params, tf_indices) + self.assertAllEqual(params[indices], gather.eval()) + self.assertEqual(indices.shape + params.shape[1:], gather.get_shape()) + # Test gradients + gather_grad = np.random.randn(*gather.get_shape().as_list()) + params_grad, indices_grad = tf.gradients( + gather, [tf_params, tf_indices], gather_grad) + 
self.assertEqual(indices_grad, None) + self.assertEqual(type(params_grad), tf.IndexedSlices) + params_grad = tf.convert_to_tensor(params_grad) + correct_params_grad = np.zeros(shape) + for i, g in zip(indices.flat, + gather_grad.reshape((indices.size,) + shape[1:])): + correct_params_grad[i] += g + self.assertAllClose(correct_params_grad, params_grad.eval()) def testUnknownIndices(self): params = tf.constant([[0, 1, 2]]) @@ -92,6 +95,15 @@ class GatherTest(tf.test.TestCase): with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): gather.eval() + def testEmptySlices(self): + with self.test_session(use_gpu=self.use_gpu): + for dtype in np.float32, np.float64: + for itype in np.int32, np.int64: + params = np.zeros((7, 0), dtype=dtype) + indices = np.array([3, 4], dtype=itype) + gather = tf.gather(params, indices) + self.assertAllEqual(gather.eval(), np.zeros((2, 0))) + class GatherGpuTest(GatherTest): use_gpu = True |