diff options
author | 2016-06-23 15:00:07 -0800 | |
---|---|---|
committer | 2016-06-23 16:03:23 -0700 | |
commit | eff93149a6dc8e6826898fd9f9c28c81e21c9836 (patch) | |
tree | 6836053bf76e1d0de8b3c3a2619bd826122ae7b4 /tensorflow/python | |
parent | 6f0dc9000e250f7aabf9bcdf7b2c5a33352b5b9b (diff) |
Add axis option to tf.unpack and tf.pack.
Change: 125729413
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/kernel_tests/pack_op_test.py | 57 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/unpack_op_test.py | 68 | ||||
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 72 |
4 files changed, 177 insertions, 26 deletions
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py
index 349def7181..5d7055824d 100644
--- a/tensorflow/python/kernel_tests/pack_op_test.py
+++ b/tensorflow/python/kernel_tests/pack_op_test.py
@@ -22,6 +22,14 @@ import numpy as np
 import tensorflow as tf
 
 
+def np_split_squeeze(array, axis):
+  axis_len = array.shape[axis]
+  return [
+      np.squeeze(arr, axis=(axis,))
+      for arr in np.split(array, axis_len, axis=axis)
+  ]
+
+
 class PackOpTest(tf.test.TestCase):
 
   def testSimple(self):
@@ -61,7 +69,7 @@ class PackOpTest(tf.test.TestCase):
       b = tf.reshape(a, tf.pack([2, 3]))
       self.assertAllEqual(b.get_shape(), [2, 3])
 
-  def testGradients(self):
+  def testGradientsAxis0(self):
     np.random.seed(7)
     for use_gpu in False, True:
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -74,6 +82,21 @@ class PackOpTest(tf.test.TestCase):
           err = tf.test.compute_gradient_error(xs, shapes, c, shape)
           self.assertLess(err, 1e-6)
 
+  def testGradientsAxis1(self):
+    np.random.seed(7)
+    for use_gpu in False, True:
+      for shape in (2, 3), (3, 2), (4, 3, 2):
+        data = np.random.randn(*shape)
+        shapes = [shape[1:]] * shape[0]
+        out_shape = list(shape[1:])
+        out_shape.insert(1, shape[0])
+        with self.test_session(use_gpu=use_gpu):
+          # TODO(irving): Remove list() once we handle maps correctly
+          xs = list(map(tf.constant, data))
+          c = tf.pack(xs, axis=1)
+          err = tf.test.compute_gradient_error(xs, shapes, c, out_shape)
+          self.assertLess(err, 1e-6)
+
   def testZeroSize(self):
     # Verify that pack doesn't crash for zero size inputs
     for use_gpu in False, True:
@@ -83,6 +106,38 @@ class PackOpTest(tf.test.TestCase):
           p = tf.pack(list(x)).eval()
           self.assertAllEqual(p, x)
 
+  def testAxis0Default(self):
+    with self.test_session():
+      t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+
+      packed = tf.pack(t).eval()
+
+    self.assertAllEqual(packed, np.array([[1, 2, 3], [4, 5, 6]]))
+
+  def testAgainstNumpy(self):
+    # For 1 to 5 dimensions.
+    for i in range(1, 6):
+      expected = np.random.random(np.random.permutation(i) + 1)
+
+      # For all the possible axis to split it, including negative indices.
+      for j in range(-i, i):
+        test_arrays = np_split_squeeze(expected, j)
+
+        with self.test_session():
+          actual = tf.pack(test_arrays, axis=j).eval()
+
+        self.assertNDArrayNear(expected, actual, 1e-6)
+
+  def testDimOutOfRange(self):
+    t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+    with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"):
+      tf.pack(t, axis=2)
+
+  def testDimOutOfNegativeRange(self):
+    t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+    with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"):
+      tf.pack(t, axis=-3)
+
 
 class AutomaticPackingTest(tf.test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 7cc6d31efb..0cb701db82 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -23,6 +23,14 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
 
+def np_split_squeeze(array, axis):
+  axis_len = array.shape[axis]
+  return [
+      np.squeeze(arr, axis=(axis,))
+      for arr in np.split(array, axis_len, axis=axis)
+  ]
+
+
 class UnpackOpTest(tf.test.TestCase):
 
   def testSimple(self):
@@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase):
           cs = [c.eval() for c in cs]
           self.assertAllEqual(cs, data)
 
-  def testGradients(self):
+  def testGradientsAxis0(self):
     for use_gpu in False, True:
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
         data = np.random.randn(*shape)
@@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase):
             err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i])
             self.assertLess(err, 1e-6)
 
+  def testGradientsAxis1(self):
+    for use_gpu in False, True:
+      for shape in (2, 3), (3, 2), (4, 3, 2):
+        data = np.random.randn(*shape)
+        out_shape = list(shape)
+        del out_shape[1]
+        for i in xrange(shape[1]):
+          with self.test_session(use_gpu=use_gpu):
+            x = tf.constant(data)
+            cs = tf.unpack(x, num=shape[1], axis=1)
+            err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape)
+            self.assertLess(err, 1e-6)
+
   def testInferNum(self):
     with self.test_session():
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -60,12 +81,55 @@ class UnpackOpTest(tf.test.TestCase):
         self.assertEqual(type(cs), list)
         self.assertEqual(len(cs), shape[0])
 
-  def testCannotInferNum(self):
+  def testCannotInferNumFromUnknownShape(self):
     x = tf.placeholder(np.float32)
     with self.assertRaisesRegexp(
         ValueError, r'Cannot infer num from shape <unknown>'):
       tf.unpack(x)
 
+  def testUnknownShapeOkWithNum(self):
+    x = tf.placeholder(np.float32)
+    tf.unpack(x, num=2)
+
+  def testCannotInferNumFromNoneShape(self):
+    x = tf.placeholder(np.float32, shape=(None,))
+    with self.assertRaisesRegexp(ValueError,
+                                 r'Cannot infer num from shape \(\?,\)'):
+      tf.unpack(x)
+
+  def testAgainstNumpy(self):
+    # For 1 to 5 dimensions.
+    for i in range(1, 6):
+      a = np.random.random(np.random.permutation(i) + 1)
+
+      # For all the possible axis to split it, including negative indices.
+      for j in range(-i, i):
+        expected = np_split_squeeze(a, j)
+
+        with self.test_session() as sess:
+          actual = sess.run(tf.unpack(a, axis=j))
+
+        self.assertAllEqual(expected, actual)
+
+  def testAxis0Default(self):
+    with self.test_session() as sess:
+      a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+
+      unpacked = sess.run(tf.unpack(a))
+
+    self.assertEqual(len(unpacked), 2)
+    self.assertAllEqual(unpacked[0], [1, 2, 3])
+    self.assertAllEqual(unpacked[1], [4, 5, 6])
+
+  def testAxisOutOfRange(self):
+    a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+    with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
+      tf.unpack(a, axis=2)
+
+  def testAxisOutOfNegativeRange(self):
+    a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+    with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
+      tf.unpack(a, axis=-3)
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 0e85aaf80f..5ff7cb95f4 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -30,13 +30,13 @@ from tensorflow.python.ops import math_ops
 @ops.RegisterGradient("Pack")
 def _PackGrad(op, grad):
   """Gradient for pack op."""
-  return array_ops.unpack(grad, num=op.get_attr("N"))
+  return array_ops.unpack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))
 
 
 @ops.RegisterGradient("Unpack")
-def _UnpackGrad(_, *grads):
+def _UnpackGrad(op, *grads):
   """Gradient for unpack op."""
-  return array_ops.pack(grads)
+  return array_ops.pack(grads, axis=op.get_attr("axis"))
 
 
 @ops.RegisterGradient("Concat")
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index f7c1f22ffd..53e9216380 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -467,7 +467,7 @@ def strided_slice(input_,
 ops.Tensor._override_operator("__getitem__", _SliceHelper)
 
 
-def pack(values, name="pack"):
+def pack(values, axis=0, name="pack"):
   """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
 
   Packs tensors in `values` into a tensor with rank one higher than each tensor
@@ -480,17 +480,31 @@ def pack(values, name="pack"):
 
   Args:
     values: A list of `Tensor` objects with the same shape and type.
+    axis: An `int`. The axis to pack along. Defaults to the first dimension.
+      Supports negative indexes.
     name: A name for this operation (optional).
 
   Returns:
     output: A packed `Tensor` with the same type as `values`.
+
+  Raises:
+    ValueError: If `axis` is out of the range [-(R+1), R+1).
   """
-  try:
-    # If the input is a constant list, it can just be converted to a constant op
-    return ops.convert_to_tensor(values, name=name)
-  except (TypeError, ValueError):
-    # Input list contains non-constant tensors
-    return gen_array_ops._pack(values, name=name)
+  if axis == 0:
+    try:
+      # If the input is a constant list, it can be converted to a constant op
+      return ops.convert_to_tensor(values, name=name)
+    except (TypeError, ValueError):
+      pass  # Input list contains non-constant tensors
+
+  value_shape = ops.convert_to_tensor(values[0], name=name).get_shape()
+  if value_shape.ndims is not None:
+    expanded_num_dims = value_shape.ndims + 1
+    if axis < -expanded_num_dims or axis >= expanded_num_dims:
+      raise ValueError("axis = %d not in [%d, %d)" %
+                       (axis, -expanded_num_dims, expanded_num_dims))
+
+  return gen_array_ops._pack(values, axis=axis, name=name)
 
 
 # pylint: disable=invalid-name
@@ -581,12 +595,12 @@ ops.register_tensor_conversion_function(
     (list, tuple), _autopacking_conversion_function, 99)
 
 
-def unpack(value, num=None, name="unpack"):
-  """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+def unpack(value, num=None, axis=0, name="unpack"):
+  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
 
-  Unpacks `num` tensors from `value` along the first dimension.
+  Unpacks `num` tensors from `value` along the given dimension.
   If `num` is not specified (the default), it is inferred from `value`'s shape.
-  If `value.shape[0]` is not known, `ValueError` is raised.
+  If `value.shape[axis]` is not known, `ValueError` is raised.
 
   The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
   `output` has shape `value.shape[1:]`.
@@ -597,8 +611,10 @@ def unpack(value, num=None, name="unpack"):
 
   Args:
     value: A rank `R > 0` `Tensor` to be unpacked.
-    num: An `int`. The first dimension of value. Automatically inferred if
-      `None` (the default).
+    num: An `int`. The length of the dimension `axis`. Automatically inferred
+      if `None` (the default).
+    axis: An `int`. The axis to unpack along. Defaults to the first
+      dimension. Supports negative indexes.
     name: A name for the operation (optional).
 
   Returns:
@@ -606,14 +622,19 @@ def unpack(value, num=None, name="unpack"):
 
   Raises:
     ValueError: If `num` is unspecified and cannot be inferred.
+    ValueError: If `axis` is out of the range [-R, R).
   """
   if num is None:
     value = ops.convert_to_tensor(value)
-    shape = value.get_shape()
-    num = shape[0].value
-    if num is None:
-      raise ValueError("Cannot infer num from shape %s" % shape)
-  return gen_array_ops._unpack(value, num=num, name=name)
+    value_shape = value.get_shape()
+    if value_shape.ndims is not None:
+      if axis < -value_shape.ndims or axis >= value_shape.ndims:
+        raise ValueError("axis = %d not in [%d, %d)" %
+                         (axis, -value_shape.ndims, value_shape.ndims))
+      num = value_shape[axis].value
+    if num is None:
+      raise ValueError("Cannot infer num from shape %s" % value_shape)
+  return gen_array_ops._unpack(value, num=num, axis=axis, name=name)
 
 
 def concat(concat_dim, values, name="concat"):
@@ -679,15 +700,26 @@ def concat(concat_dim, values, name="concat"):
 @ops.RegisterShape("Pack")
 def _PackShape(op):
   input_shape = op.inputs[0].get_shape()
+  if input_shape.ndims is None:
+    return [tensor_shape.unknown_shape()]
+
   for inp in op.inputs[1:]:
     input_shape = input_shape.merge_with(inp.get_shape())
-  return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+  axis = op.get_attr("axis")  # May be negative: wrap before list.insert().
+  input_shape = input_shape.as_list()
+  input_shape.insert(axis % (len(input_shape) + 1), len(op.inputs))
+  return [tensor_shape.TensorShape(input_shape)]
 
 
 @ops.RegisterShape("Unpack")
 def _UnpackShape(op):
   input_shape = op.inputs[0].get_shape()
-  return [input_shape[1:]] * op.get_attr("num")
+  if input_shape.ndims is None:
+    return [tensor_shape.unknown_shape()] * op.get_attr("num")
+
+  input_shape = input_shape.as_list()
+  del input_shape[op.get_attr("axis")]
+  return [tensor_shape.TensorShape(input_shape)] * op.get_attr("num")
 
 
 @ops.RegisterShape("Concat")