aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_op_test.py
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-08-18 21:45:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-18 23:03:46 -0700
commitd35ba035a09d60838ae6344ba251202a9a78c47e (patch)
treef4f20c26006c0b39e8e71a040a2d6d935d390831 /tensorflow/python/kernel_tests/gather_op_test.py
parent37000ef3b5a63a8cf9b6e8fd3dd8059aba0e6ddc (diff)
Fix gather gradient for empty slices
We can't use -1's in the reshape if the size is 0, since the -1 would be ambiguous. Also simplify the code now that array_ops.shape does its own checking for fully defined shapes. A previous version used too much colocate_with and broke a distributed model. This version uses colocate_with only for params following the previous version of the code. Change: 130720593
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py43
1 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 3bd4fe8b72..272f28d437 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase):
def testHigherRank(self):
np.random.seed(1)
- shape = (4, 3, 2)
- params = np.random.randn(*shape)
- indices = np.random.randint(shape[0], size=15).reshape(3, 5)
- with self.test_session(use_gpu=self.use_gpu):
- tf_params = tf.constant(params)
- tf_indices = tf.constant(indices)
- gather = tf.gather(tf_params, tf_indices)
- self.assertAllEqual(params[indices], gather.eval())
- self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
- # Test gradients
- gather_grad = np.random.randn(*gather.get_shape().as_list())
- params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices],
- gather_grad)
- self.assertEqual(indices_grad, None)
- self.assertEqual(type(params_grad), tf.IndexedSlices)
- params_grad = tf.convert_to_tensor(params_grad)
- correct_params_grad = np.zeros(shape)
- for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])):
- correct_params_grad[i] += g
- self.assertAllClose(correct_params_grad, params_grad.eval())
+ # We check that scalar and empty shapes work as well
+ for shape in (7, 0), (4, 3, 2):
+ for indices_shape in (), (0,), (3, 0), (3, 5):
+ params = np.random.randn(*shape)
+ indices = np.random.randint(shape[0], size=indices_shape)
+ with self.test_session(use_gpu=self.use_gpu):
+ tf_params = tf.constant(params)
+ tf_indices = tf.constant(indices)
+ gather = tf.gather(tf_params, tf_indices)
+ self.assertAllEqual(params[indices], gather.eval())
+ self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
+ # Test gradients
+ gather_grad = np.random.randn(*gather.get_shape().as_list())
+ params_grad, indices_grad = tf.gradients(
+ gather, [tf_params, tf_indices], gather_grad)
+ self.assertEqual(indices_grad, None)
+ self.assertEqual(type(params_grad), tf.IndexedSlices)
+ params_grad = tf.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape)
+ for i, g in zip(indices.flat,
+ gather_grad.reshape((indices.size,) + shape[1:])):
+ correct_params_grad[i] += g
+ self.assertAllClose(correct_params_grad, params_grad.eval())
def testUnknownIndices(self):
params = tf.constant([[0, 1, 2]])