diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2016-08-18 21:45:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-18 23:03:46 -0700 |
commit | d35ba035a09d60838ae6344ba251202a9a78c47e (patch) | |
tree | f4f20c26006c0b39e8e71a040a2d6d935d390831 /tensorflow/python/kernel_tests/gather_op_test.py | |
parent | 37000ef3b5a63a8cf9b6e8fd3dd8059aba0e6ddc (diff) |
Fix gather gradient for empty slices
We can't use -1's in the reshape if the size is 0, since the -1 would be
ambiguous. Also simplify the code now that array_ops.shape does its own
checking for fully defined shapes.
A previous version used too much colocate_with and broke a distributed
model. This version uses colocate_with only for params, following the
previous version of the code.
Change: 130720593
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 43 |
1 files changed, 23 insertions, 20 deletions
```diff
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 3bd4fe8b72..272f28d437 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase):
 
   def testHigherRank(self):
     np.random.seed(1)
-    shape = (4, 3, 2)
-    params = np.random.randn(*shape)
-    indices = np.random.randint(shape[0], size=15).reshape(3, 5)
-    with self.test_session(use_gpu=self.use_gpu):
-      tf_params = tf.constant(params)
-      tf_indices = tf.constant(indices)
-      gather = tf.gather(tf_params, tf_indices)
-      self.assertAllEqual(params[indices], gather.eval())
-      self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
-      # Test gradients
-      gather_grad = np.random.randn(*gather.get_shape().as_list())
-      params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices],
-                                               gather_grad)
-      self.assertEqual(indices_grad, None)
-      self.assertEqual(type(params_grad), tf.IndexedSlices)
-      params_grad = tf.convert_to_tensor(params_grad)
-      correct_params_grad = np.zeros(shape)
-      for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])):
-        correct_params_grad[i] += g
-      self.assertAllClose(correct_params_grad, params_grad.eval())
+    # We check that scalar and empty shapes work as well
+    for shape in (7, 0), (4, 3, 2):
+      for indices_shape in (), (0,), (3, 0), (3, 5):
+        params = np.random.randn(*shape)
+        indices = np.random.randint(shape[0], size=indices_shape)
+        with self.test_session(use_gpu=self.use_gpu):
+          tf_params = tf.constant(params)
+          tf_indices = tf.constant(indices)
+          gather = tf.gather(tf_params, tf_indices)
+          self.assertAllEqual(params[indices], gather.eval())
+          self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
+          # Test gradients
+          gather_grad = np.random.randn(*gather.get_shape().as_list())
+          params_grad, indices_grad = tf.gradients(
+              gather, [tf_params, tf_indices], gather_grad)
+          self.assertEqual(indices_grad, None)
+          self.assertEqual(type(params_grad), tf.IndexedSlices)
+          params_grad = tf.convert_to_tensor(params_grad)
+          correct_params_grad = np.zeros(shape)
+          for i, g in zip(indices.flat,
+                          gather_grad.reshape((indices.size,) + shape[1:])):
+            correct_params_grad[i] += g
+          self.assertAllClose(correct_params_grad, params_grad.eval())
 
   def testUnknownIndices(self):
     params = tf.constant([[0, 1, 2]])
```