aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_op_test.py
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-08-17 12:16:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-17 13:33:16 -0700
commit7817445b6677f66bba04ec7f3c0836cd3b719011 (patch)
tree00efbc4d3770a827c2e29b53309302bd47fedaa9 /tensorflow/python/kernel_tests/gather_op_test.py
parent0fcaae1ebd5943b0622f9953dbe0fb22f8eafbfc (diff)
Fix gather for nonempty indices, empty slices
tf.gather([0], [[]]) now works. This required 1. Avoiding an empty GPU kernel launch. 2. Avoiding the use of -1 reshapes in the gradient. This code was also simplified a bit using the fact that array_ops.shape does its own checking for fully defined shapes. Change: 130554783
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py52
1 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index c3ef9e491d..272f28d437 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -57,26 +57,29 @@ class GatherTest(tf.test.TestCase):
def testHigherRank(self):
np.random.seed(1)
- shape = (4, 3, 2)
- params = np.random.randn(*shape)
- indices = np.random.randint(shape[0], size=15).reshape(3, 5)
- with self.test_session(use_gpu=self.use_gpu):
- tf_params = tf.constant(params)
- tf_indices = tf.constant(indices)
- gather = tf.gather(tf_params, tf_indices)
- self.assertAllEqual(params[indices], gather.eval())
- self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
- # Test gradients
- gather_grad = np.random.randn(*gather.get_shape().as_list())
- params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices],
- gather_grad)
- self.assertEqual(indices_grad, None)
- self.assertEqual(type(params_grad), tf.IndexedSlices)
- params_grad = tf.convert_to_tensor(params_grad)
- correct_params_grad = np.zeros(shape)
- for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])):
- correct_params_grad[i] += g
- self.assertAllClose(correct_params_grad, params_grad.eval())
+ # We check that scalar and empty shapes work as well
+ for shape in (7, 0), (4, 3, 2):
+ for indices_shape in (), (0,), (3, 0), (3, 5):
+ params = np.random.randn(*shape)
+ indices = np.random.randint(shape[0], size=indices_shape)
+ with self.test_session(use_gpu=self.use_gpu):
+ tf_params = tf.constant(params)
+ tf_indices = tf.constant(indices)
+ gather = tf.gather(tf_params, tf_indices)
+ self.assertAllEqual(params[indices], gather.eval())
+ self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
+ # Test gradients
+ gather_grad = np.random.randn(*gather.get_shape().as_list())
+ params_grad, indices_grad = tf.gradients(
+ gather, [tf_params, tf_indices], gather_grad)
+ self.assertEqual(indices_grad, None)
+ self.assertEqual(type(params_grad), tf.IndexedSlices)
+ params_grad = tf.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape)
+ for i, g in zip(indices.flat,
+ gather_grad.reshape((indices.size,) + shape[1:])):
+ correct_params_grad[i] += g
+ self.assertAllClose(correct_params_grad, params_grad.eval())
def testUnknownIndices(self):
params = tf.constant([[0, 1, 2]])
@@ -92,6 +95,15 @@ class GatherTest(tf.test.TestCase):
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
gather.eval()
+ def testEmptySlices(self):
+ with self.test_session(use_gpu=self.use_gpu):
+ for dtype in np.float32, np.float64:
+ for itype in np.int32, np.int64:
+ params = np.zeros((7, 0), dtype=dtype)
+ indices = np.array([3, 4], dtype=itype)
+ gather = tf.gather(params, indices)
+ self.assertAllEqual(gather.eval(), np.zeros((2, 0)))
+
class GatherGpuTest(GatherTest):
use_gpu = True