diff options
author | 2017-07-11 09:51:54 -0700 | |
---|---|---|
committer | 2017-07-11 09:55:52 -0700 | |
commit | b1f9e2c89eb007cb4b9483d08dcace1e45e84164 (patch) | |
tree | b2b82fc0bd6abf3b77a412a251fcfdacf70a21dc /tensorflow/python/kernel_tests/gather_op_test.py | |
parent | 18a5510e67ef536c947512b70030c5c995ce7875 (diff) |
Add an axis parameter to tf.gather. Fixes GitHub issue #11223.
This brings tf.gather closer to compatibility with numpy.take.
Emulating a gather over an arbitrary axis generally requires inefficient workarounds, e.g. transpose/gather/transpose. This technique is gaining popularity (hundreds of uses inside and outside of Google), so it is worth supporting efficiently.
For an `[a_0, ..., a_i, ..., a_n]` tensor, gathering `N` elements from axis `i` requires `(a_0 * ... * a_{i-1}) * N` copies of `(a_{i+1} * ... * a_n)` elements each. The CPU kernel does this with memcpy, which is far more efficient than transpose/gather/transpose since it requires no intermediate allocations and copies. The GPU kernel does the same number of copies but in parallel across multiple hardware threads.
Since this is a backwards incompatible change, this adds a "GatherV2" op with an axis input, and simultaneously supports backwards compatibility with "Gather" ops by defaulting to axis 0 if a 3rd input is not present.
PiperOrigin-RevId: 161541416
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 192 |
1 files changed, 136 insertions, 56 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index b3ce234d4e..04d65b88a1 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -44,70 +44,110 @@ class GatherTest(test.TestCase): with self.test_session(use_gpu=True): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in _TEST_TYPES: - params_np = self._buildParams(data, dtype) - params = constant_op.constant(params_np) - indices = constant_op.constant(4) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual(params_np[4], gather_val) - self.assertEqual([], gather_t.get_shape()) + for indices in 4, [1, 2, 2, 4, 5]: + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = gather_t.eval() + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + self.assertEqual(np_val.shape, gather_t.get_shape()) def testScalar2D(self): with self.test_session(use_gpu=True): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in _TEST_TYPES: - params_np = self._buildParams(data, dtype) - params = constant_op.constant(params_np) - indices = constant_op.constant(2) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual(params_np[2], gather_val) - self.assertEqual([3], gather_t.get_shape()) + for axis in range(data.ndim): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant(2) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = gather_t.eval() + self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val) + expected_shape = data.shape[:axis] + data.shape[axis + 1:] + self.assertEqual(expected_shape, gather_t.get_shape()) def 
testSimpleTwoD32(self): with self.test_session(use_gpu=True): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in _TEST_TYPES: - params_np = self._buildParams(data, dtype) - params = constant_op.constant(params_np) - indices = constant_op.constant([0, 4, 0, 2]) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual(params_np[[0, 4, 0, 2]], gather_val) - self.assertEqual([4, 3], gather_t.get_shape()) + for axis in range(data.ndim): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + # The indices must be in bounds for any axis. + indices = constant_op.constant([0, 1, 0, 2]) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = gather_t.eval() + self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis), + gather_val) + expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:] + self.assertEqual(expected_shape, gather_t.get_shape()) def testHigherRank(self): - np.random.seed(1) - # We check that scalar and empty shapes work as well - for shape in (7, 0), (4, 3, 2): - for indices_shape in (), (0,), (3, 0), (3, 5): + # We check that scalar and empty indices shapes work as well + for shape in (4, 3, 2), (2, 1, 3, 2): + for indices_shape in (), (0,), (3, 0), (3, 5), (5, 2, 3): for dtype in _TEST_TYPES: - params = self._buildParams(np.random.randn(*shape), dtype) - indices = np.random.randint(shape[0], size=indices_shape) - with self.test_session(use_gpu=True): - tf_params = constant_op.constant(params) - tf_indices = constant_op.constant(indices) - gather = array_ops.gather(tf_params, tf_indices) - self.assertAllEqual(params[indices], gather.eval()) - self.assertEqual(indices.shape + params.shape[1:], - gather.get_shape()) - # Test gradients - gather_grad = np.random.randn(*gather.get_shape().as_list()).astype( - dtype.as_numpy_dtype) - if dtype.is_complex: - gather_grad -= 1j * gather_grad - params_grad, indices_grad = 
gradients_impl.gradients( - gather, [tf_params, tf_indices], gather_grad) - self.assertEqual(indices_grad, None) - self.assertEqual(type(params_grad), ops.IndexedSlices) - params_grad = ops.convert_to_tensor(params_grad) - correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype) - for i, g in zip(indices.flat, - gather_grad.reshape((indices.size,) + shape[1:])): - correct_params_grad[i] += g - self.assertAllClose(correct_params_grad, params_grad.eval()) + for axis in range(len(shape)): + params = self._buildParams(np.random.randn(*shape), dtype) + indices = np.random.randint(shape[axis], size=indices_shape) + with self.test_session(use_gpu=True) as sess: + tf_params = constant_op.constant(params) + tf_indices = constant_op.constant(indices) + # Check that both positive and negative indices for axis work. + tf_axis = constant_op.constant(axis) + tf_negative_axis = constant_op.constant(-len(shape) + axis) + gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis) + gather_negative_axis = array_ops.gather( + tf_params, tf_indices, axis=tf_negative_axis) + gather_value, gather_negative_axis_value = sess.run( + [gather, gather_negative_axis]) + gather_np = np.take(params, indices, axis) + self.assertAllEqual(gather_np, gather_value) + self.assertAllEqual(gather_np, gather_negative_axis_value) + expected_shape = (params.shape[:axis] + indices.shape + + params.shape[axis + 1:]) + self.assertEqual(expected_shape, gather.shape) + self.assertEqual(expected_shape, gather_negative_axis.shape) + + # Test gradients + gather_grad = np.random.randn( + *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype) + if dtype.is_complex: + gather_grad -= 1j * gather_grad + params_grad, indices_grad, axis_grad = gradients_impl.gradients( + gather, [tf_params, tf_indices, tf_axis], gather_grad) + self.assertEqual(indices_grad, None) + self.assertEqual(axis_grad, None) + # For axis 0, we are able to create an efficient IndexedSlices for + # the gradient. 
+ if axis == 0: + self.assertEqual(type(params_grad), ops.IndexedSlices) + params_grad = ops.convert_to_tensor(params_grad) + correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype) + outer_dims = axis + inner_dims = len(shape) - axis - 1 + gather_grad = gather_grad.reshape( + shape[:axis] + (indices.size,) + shape[axis + 1:]) + for source_index, dest_index in enumerate(indices.flat): + dest_slice = ((slice(None),) * outer_dims + (dest_index,) + + (slice(None),) * inner_dims) + source_slice = ((slice(None),) * outer_dims + (source_index,) + + (slice(None),) * inner_dims) + correct_params_grad[dest_slice] += gather_grad[source_slice] + self.assertAllClose(correct_params_grad, params_grad.eval(), + atol=2e-6, rtol=2e-6) + + def testString(self): + params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]]) + with self.test_session(): + self.assertAllEqual([b"qwer", b"uiop"], + array_ops.gather(params, 1, axis=0).eval()) + self.assertAllEqual([b"asdf", b"qwer"], + array_ops.gather(params, 0, axis=1).eval()) def testUnknownIndices(self): params = constant_op.constant([[0, 1, 2]]) @@ -115,22 +155,62 @@ class GatherTest(test.TestCase): gather_t = array_ops.gather(params, indices) self.assertEqual(None, gather_t.get_shape()) + def testUnknownAxis(self): + params = constant_op.constant([[0, 1, 2]]) + indices = constant_op.constant([[0, 0], [0, 0]]) + axis = array_ops.placeholder(dtypes.int32) + gather_t = array_ops.gather(params, indices, axis=axis) + # Rank 2 params with rank 2 indices results in a rank 3 shape. + self.assertEqual([None, None, None], gather_t.shape.as_list()) + + # If indices is also unknown the result rank is unknown. 
+ indices = array_ops.placeholder(dtypes.int32) + gather_t = array_ops.gather(params, indices, axis=axis) + self.assertEqual(None, gather_t.shape) + def testBadIndices(self): with self.test_session(use_gpu=True): - params = [0, 1, 2] - indices = [[7]] - gather = array_ops.gather(params, indices) + params = [[0, 1, 2], [3, 4, 5]] + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): + array_ops.gather(params, [[7]], axis=0).eval() with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): - gather.eval() + array_ops.gather(params, [[7]], axis=1).eval() + + def testBadAxis(self): + with self.test_session(use_gpu=True): + params = [0, 1, 2] + params_ph = array_ops.placeholder(dtypes.int32) + indices = 0 + for bad_axis in (1, 2, -2): + # Shape inference can validate axis for known params rank. + with self.assertRaisesWithPredicateMatch( + ValueError, "Shape must be at least rank . but is rank 1"): + array_ops.gather(params, indices, axis=bad_axis) + # If params rank is unknown, an op error occurs. + with self.assertRaisesOpError( + r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis): + array_ops.gather(params_ph, indices, axis=bad_axis).eval( + feed_dict={params_ph: params}) def testEmptySlices(self): with self.test_session(use_gpu=True): for dtype in _TEST_TYPES: for itype in np.int32, np.int64: - params = np.zeros((7, 0), dtype=dtype.as_numpy_dtype) + # Leading axis gather. + params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype) indices = np.array([3, 4], dtype=itype) - gather = array_ops.gather(params, indices) - self.assertAllEqual(gather.eval(), np.zeros((2, 0))) + gather = array_ops.gather(params, indices, axis=0) + self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0))) + + # Middle axis gather. + params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype) + gather = array_ops.gather(params, indices, axis=1) + self.assertAllEqual(gather.eval(), np.zeros((0, 2, 0))) + + # Trailing axis gather. 
+ params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype) + gather = array_ops.gather(params, indices, axis=2) + self.assertAllEqual(gather.eval(), np.zeros((0, 0, 2))) if __name__ == "__main__": |