aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-23 15:00:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-23 16:03:23 -0700
commiteff93149a6dc8e6826898fd9f9c28c81e21c9836 (patch)
tree6836053bf76e1d0de8b3c3a2619bd826122ae7b4 /tensorflow/python
parent6f0dc9000e250f7aabf9bcdf7b2c5a33352b5b9b (diff)
Add axis option to tf.unpack and tf.pack.
Change: 125729413
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/kernel_tests/pack_op_test.py57
-rw-r--r--tensorflow/python/kernel_tests/unpack_op_test.py68
-rw-r--r--tensorflow/python/ops/array_grad.py6
-rw-r--r--tensorflow/python/ops/array_ops.py72
4 files changed, 177 insertions, 26 deletions
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py
index 349def7181..5d7055824d 100644
--- a/tensorflow/python/kernel_tests/pack_op_test.py
+++ b/tensorflow/python/kernel_tests/pack_op_test.py
@@ -22,6 +22,14 @@ import numpy as np
import tensorflow as tf
+def np_split_sqeeze(array, axis):
+ axis_len = array.shape[axis]
+ return [
+ np.squeeze(arr, axis=(axis,))
+ for arr in np.split(array, axis_len, axis=axis)
+ ]
+
+
class PackOpTest(tf.test.TestCase):
def testSimple(self):
@@ -61,7 +69,7 @@ class PackOpTest(tf.test.TestCase):
b = tf.reshape(a, tf.pack([2, 3]))
self.assertAllEqual(b.get_shape(), [2, 3])
- def testGradients(self):
+ def testGradientsAxis0(self):
np.random.seed(7)
for use_gpu in False, True:
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -74,6 +82,21 @@ class PackOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(xs, shapes, c, shape)
self.assertLess(err, 1e-6)
+ def testGradientsAxis1(self):
+ np.random.seed(7)
+ for use_gpu in False, True:
+ for shape in (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ shapes = [shape[1:]] * shape[0]
+ out_shape = list(shape[1:])
+ out_shape.insert(1, shape[0])
+ with self.test_session(use_gpu=use_gpu):
+ # TODO(irving): Remove list() once we handle maps correctly
+ xs = list(map(tf.constant, data))
+ c = tf.pack(xs, axis=1)
+ err = tf.test.compute_gradient_error(xs, shapes, c, out_shape)
+ self.assertLess(err, 1e-6)
+
def testZeroSize(self):
# Verify that pack doesn't crash for zero size inputs
for use_gpu in False, True:
@@ -83,6 +106,38 @@ class PackOpTest(tf.test.TestCase):
p = tf.pack(list(x)).eval()
self.assertAllEqual(p, x)
+ def testAxis0Default(self):
+ with self.test_session():
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+
+ packed = tf.pack(t).eval()
+
+ self.assertAllEqual(packed, np.array([[1, 2, 3], [4, 5, 6]]))
+
+ def testAgainstNumpy(self):
+ # For 1 to 5 dimensions.
+ for i in range(1, 6):
+ expected = np.random.random(np.random.permutation(i) + 1)
+
+ # For all the possible axis to split it, including negative indices.
+ for j in range(-i, i):
+ test_arrays = np_split_sqeeze(expected, j)
+
+ with self.test_session():
+ actual = tf.pack(test_arrays, axis=j).eval()
+
+ self.assertNDArrayNear(expected, actual, 1e-6)
+
+ def testDimOutOfRange(self):
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+ with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"):
+ tf.unpack(t, axis=2)
+
+ def testDimOutOfNegativeRange(self):
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+ with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"):
+ tf.unpack(t, axis=-3)
+
class AutomaticPackingTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 7cc6d31efb..0cb701db82 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -23,6 +23,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+def np_split_sqeeze(array, axis):
+ axis_len = array.shape[axis]
+ return [
+ np.squeeze(arr, axis=(axis,))
+ for arr in np.split(array, axis_len, axis=axis)
+ ]
+
+
class UnpackOpTest(tf.test.TestCase):
def testSimple(self):
@@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase):
cs = [c.eval() for c in cs]
self.assertAllEqual(cs, data)
- def testGradients(self):
+ def testGradientsAxis0(self):
for use_gpu in False, True:
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
data = np.random.randn(*shape)
@@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i])
self.assertLess(err, 1e-6)
+ def testGradientsAxis1(self):
+ for use_gpu in False, True:
+ for shape in (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ out_shape = list(shape)
+ del out_shape[1]
+ for i in xrange(shape[1]):
+ with self.test_session(use_gpu=use_gpu):
+ x = tf.constant(data)
+ cs = tf.unpack(x, num=shape[1], axis=1)
+ err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape)
+ self.assertLess(err, 1e-6)
+
def testInferNum(self):
with self.test_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -60,12 +81,55 @@ class UnpackOpTest(tf.test.TestCase):
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
- def testCannotInferNum(self):
+ def testCannotInferNumFromUnknownShape(self):
x = tf.placeholder(np.float32)
with self.assertRaisesRegexp(
ValueError, r'Cannot infer num from shape <unknown>'):
tf.unpack(x)
+ def testUnknownShapeOkWithNum(self):
+ x = tf.placeholder(np.float32)
+ tf.unpack(x, num=2)
+
+ def testCannotInferNumFromNoneShape(self):
+ x = tf.placeholder(np.float32, shape=(None,))
+ with self.assertRaisesRegexp(ValueError,
+ r'Cannot infer num from shape \(\?,\)'):
+ tf.unpack(x)
+
+ def testAgainstNumpy(self):
+ # For 1 to 5 dimensions.
+ for i in range(1, 6):
+ a = np.random.random(np.random.permutation(i) + 1)
+
+ # For all the possible axis to split it, including negative indices.
+ for j in range(-i, i):
+ expected = np_split_sqeeze(a, j)
+
+ with self.test_session() as sess:
+ actual = sess.run(tf.unpack(a, axis=j))
+
+ self.assertAllEqual(expected, actual)
+
+ def testAxis0Default(self):
+ with self.test_session() as sess:
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+
+ unpacked = sess.run(tf.unpack(a))
+
+ self.assertEqual(len(unpacked), 2)
+ self.assertAllEqual(unpacked[0], [1, 2, 3])
+ self.assertAllEqual(unpacked[1], [4, 5, 6])
+
+ def testAxisOutOfRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
+ tf.unpack(a, axis=2)
+
+ def testAxisOutOfNegativeRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
+ tf.unpack(a, axis=-3)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 0e85aaf80f..5ff7cb95f4 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -30,13 +30,13 @@ from tensorflow.python.ops import math_ops
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
"""Gradient for pack op."""
- return array_ops.unpack(grad, num=op.get_attr("N"))
+ return array_ops.unpack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))
@ops.RegisterGradient("Unpack")
-def _UnpackGrad(_, *grads):
+def _UnpackGrad(op, *grads):
"""Gradient for unpack op."""
- return array_ops.pack(grads)
+ return array_ops.pack(grads, axis=op.get_attr("axis"))
@ops.RegisterGradient("Concat")
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index f7c1f22ffd..53e9216380 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -467,7 +467,7 @@ def strided_slice(input_,
ops.Tensor._override_operator("__getitem__", _SliceHelper)
-def pack(values, name="pack"):
+def pack(values, axis=0, name="pack"):
"""Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
Packs tensors in `values` into a tensor with rank one higher than each tensor
@@ -480,17 +480,31 @@ def pack(values, name="pack"):
Args:
values: A list of `Tensor` objects with the same shape and type.
+ axis: An `int`. The axis to pack along. Defaults to the first dimension.
+ Supports negative indexes.
name: A name for this operation (optional).
Returns:
output: A packed `Tensor` with the same type as `values`.
+
+ Raises:
+ ValueError: If `axis` is out of the range [-(R+1), R+1).
"""
- try:
- # If the input is a constant list, it can just be converted to a constant op
- return ops.convert_to_tensor(values, name=name)
- except (TypeError, ValueError):
- # Input list contains non-constant tensors
- return gen_array_ops._pack(values, name=name)
+ if axis == 0:
+ try:
+ # If the input is a constant list, it can be converted to a constant op
+ return ops.convert_to_tensor(values, name=name)
+ except (TypeError, ValueError):
+ pass # Input list contains non-constant tensors
+
+ value_shape = ops.convert_to_tensor(values[0], name=name).get_shape()
+ if value_shape.ndims is not None:
+ expanded_num_dims = value_shape.ndims + 1
+ if axis < -expanded_num_dims or axis >= expanded_num_dims:
+ raise ValueError("axis = %d not in [%d, %d)" %
+ (axis, -expanded_num_dims, expanded_num_dims))
+
+ return gen_array_ops._pack(values, axis=axis, name=name)
# pylint: disable=invalid-name
@@ -581,12 +595,12 @@ ops.register_tensor_conversion_function(
(list, tuple), _autopacking_conversion_function, 99)
-def unpack(value, num=None, name="unpack"):
- """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+def unpack(value, num=None, axis=0, name="unpack"):
+ """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
- Unpacks `num` tensors from `value` along the first dimension.
+ Unpacks `num` tensors from `value` along the given dimension.
If `num` is not specified (the default), it is inferred from `value`'s shape.
- If `value.shape[0]` is not known, `ValueError` is raised.
+ If `value.shape[axis]` is not known, `ValueError` is raised.
The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
`output` has shape `value.shape[1:]`.
@@ -597,8 +611,10 @@ def unpack(value, num=None, name="unpack"):
Args:
value: A rank `R > 0` `Tensor` to be unpacked.
- num: An `int`. The first dimension of value. Automatically inferred if
- `None` (the default).
+ num: An `int`. The length of the dimension `axis`. Automatically inferred
+ if `None` (the default).
+ axis: An `int`. The axis to unpack along. Defaults to the first
+ dimension. Supports negative indexes.
name: A name for the operation (optional).
Returns:
@@ -606,14 +622,19 @@ def unpack(value, num=None, name="unpack"):
Raises:
ValueError: If `num` is unspecified and cannot be inferred.
+ ValueError: If `axis` is out of the range [-R, R).
"""
if num is None:
value = ops.convert_to_tensor(value)
- shape = value.get_shape()
- num = shape[0].value
- if num is None:
- raise ValueError("Cannot infer num from shape %s" % shape)
- return gen_array_ops._unpack(value, num=num, name=name)
+ value_shape = value.get_shape()
+ if value_shape.ndims is not None:
+ if axis < -value_shape.ndims or axis >= value_shape.ndims:
+ raise ValueError("axis = %d not in [%d, %d)" %
+ (axis, -value_shape.ndims, value_shape.ndims))
+ num = value_shape[axis].value
+ if num is None:
+ raise ValueError("Cannot infer num from shape %s" % value_shape)
+ return gen_array_ops._unpack(value, num=num, axis=axis, name=name)
def concat(concat_dim, values, name="concat"):
@@ -679,15 +700,26 @@ def concat(concat_dim, values, name="concat"):
@ops.RegisterShape("Pack")
def _PackShape(op):
input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()]
+
for inp in op.inputs[1:]:
input_shape = input_shape.merge_with(inp.get_shape())
- return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+
+ input_shape = input_shape.as_list()
+ input_shape.insert(op.get_attr("axis"), len(op.inputs))
+ return [tensor_shape.TensorShape(input_shape)]
@ops.RegisterShape("Unpack")
def _UnpackShape(op):
input_shape = op.inputs[0].get_shape()
- return [input_shape[1:]] * op.get_attr("num")
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()] * op.get_attr("num")
+
+ input_shape = input_shape.as_list()
+ del input_shape[op.get_attr("axis")]
+ return [tensor_shape.TensorShape(input_shape)] * op.get_attr("num")
@ops.RegisterShape("Concat")