aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/unpack_op_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-23 15:00:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-23 16:03:23 -0700
commiteff93149a6dc8e6826898fd9f9c28c81e21c9836 (patch)
tree6836053bf76e1d0de8b3c3a2619bd826122ae7b4 /tensorflow/python/kernel_tests/unpack_op_test.py
parent6f0dc9000e250f7aabf9bcdf7b2c5a33352b5b9b (diff)
Add axis option to tf.unpack and tf.pack.
Change: 125729413
Diffstat (limited to 'tensorflow/python/kernel_tests/unpack_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/unpack_op_test.py68
1 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 7cc6d31efb..0cb701db82 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -23,6 +23,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+def np_split_sqeeze(array, axis):
+ axis_len = array.shape[axis]
+ return [
+ np.squeeze(arr, axis=(axis,))
+ for arr in np.split(array, axis_len, axis=axis)
+ ]
+
+
class UnpackOpTest(tf.test.TestCase):
def testSimple(self):
@@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase):
cs = [c.eval() for c in cs]
self.assertAllEqual(cs, data)
- def testGradients(self):
+ def testGradientsAxis0(self):
for use_gpu in False, True:
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
data = np.random.randn(*shape)
@@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i])
self.assertLess(err, 1e-6)
+ def testGradientsAxis1(self):
+ for use_gpu in False, True:
+ for shape in (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ out_shape = list(shape)
+ del out_shape[1]
+ for i in xrange(shape[1]):
+ with self.test_session(use_gpu=use_gpu):
+ x = tf.constant(data)
+ cs = tf.unpack(x, num=shape[1], axis=1)
+ err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape)
+ self.assertLess(err, 1e-6)
+
def testInferNum(self):
with self.test_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -60,12 +81,55 @@ class UnpackOpTest(tf.test.TestCase):
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
- def testCannotInferNum(self):
+ def testCannotInferNumFromUnknownShape(self):
x = tf.placeholder(np.float32)
with self.assertRaisesRegexp(
ValueError, r'Cannot infer num from shape <unknown>'):
tf.unpack(x)
+ def testUnknownShapeOkWithNum(self):
+ x = tf.placeholder(np.float32)
+ tf.unpack(x, num=2)
+
+ def testCannotInferNumFromNoneShape(self):
+ x = tf.placeholder(np.float32, shape=(None,))
+ with self.assertRaisesRegexp(ValueError,
+ r'Cannot infer num from shape \(\?,\)'):
+ tf.unpack(x)
+
+ def testAgainstNumpy(self):
+ # For 1 to 5 dimensions.
+ for i in range(1, 6):
+ a = np.random.random(np.random.permutation(i) + 1)
+
+ # For all the possible axis to split it, including negative indices.
+ for j in range(-i, i):
+ expected = np_split_sqeeze(a, j)
+
+ with self.test_session() as sess:
+ actual = sess.run(tf.unpack(a, axis=j))
+
+ self.assertAllEqual(expected, actual)
+
+ def testAxis0Default(self):
+ with self.test_session() as sess:
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+
+ unpacked = sess.run(tf.unpack(a))
+
+ self.assertEqual(len(unpacked), 2)
+ self.assertAllEqual(unpacked[0], [1, 2, 3])
+ self.assertAllEqual(unpacked[1], [4, 5, 6])
+
+ def testAxisOutOfRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
+ tf.unpack(a, axis=2)
+
+ def testAxisOutOfNegativeRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
+ tf.unpack(a, axis=-3)
if __name__ == '__main__':
tf.test.main()